Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions distconv/distconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,11 @@ def _handle_circular_pad(cls, func, args, kwargs):
Handle circular padding for DCTensor by applying normal padding to non-shard dimensions
and marking the shard dimension for circular handling during conv operations.

For the most common use case of uniform padding, this is called recursively:
first with pad_list = (1,1)*ndims, which then sets the shard-dimension padding to 0
and calls the padding func again. We do this instead of operating on the torch tensor
directly to ensure the memory is freed and no graph breaks occur.

Args:
func (Callable): The F.pad function.
args (Tuple): The arguments to F.pad.
Expand All @@ -711,10 +716,14 @@ def _handle_circular_pad(cls, func, args, kwargs):
"""
input_tensor = args[0]
pad = args[1] if len(args) > 1 else kwargs.get("pad")
pad_list = list(pad)

parallel_strategy = input_tensor._parallel_strategy
shard_dim = parallel_strategy.shard_dim
pad_list = list(pad)
# If no padding in the sharded dimensions, then we do the padding using the standard op
if all([True if pad_list[::-2][s - 2] == 0 else False for s in shard_dim]):
return super().__torch_function__(func, (DCTensor,), args, kwargs)

shard_padding = [
0,
] * len(shard_dim)
Expand Down Expand Up @@ -747,11 +756,10 @@ def _handle_circular_pad(cls, func, args, kwargs):
is_periodic[i] = False

# Call F.pad with modified padding (shard dim padding disabled)
new_args = (_ToTensor.apply(input_tensor), tuple(pad_list)) + args[2:]
partial_padded_tensor = func(*new_args, **kwargs)

# Create result DCTensor with periodic flag and stored shard padding
result: DCTensor = _FromTensor.apply(partial_padded_tensor, parallel_strategy)
new_args = (input_tensor, tuple(pad_list)) + args[2:]
# Re-run padding function on DCTensor with periodic flag and stored shard padding.
# This will not actually apply the padding, but rather call __torch_function__.
result = func(*new_args, **kwargs)
result._is_periodic = tuple(is_periodic)
result._periodic_shard_padding = tuple(shard_padding)

Expand Down