diff --git a/distconv/distconv.py b/distconv/distconv.py
index 7e9ed6b..59818e2 100644
--- a/distconv/distconv.py
+++ b/distconv/distconv.py
@@ -701,6 +701,11 @@ def _handle_circular_pad(cls, func, args, kwargs):
         Handle circular padding for DCTensor by applying normal padding
         to non-shard dimensions and marking the shard dimension for
         circular handling during conv operations.
+        For the most common use-case of uniform padding, this is called recursively.
+        First with pad_list = (1,1)*ndims, which will then mark the shard dimension padding as 0
+        and call the padding func again. We do this instead of operating on the torch tensor directly
+        to ensure the memory is freed and no graph breaks.
+
         Args:
             func (Callable): The F.pad function.
             args (Tuple): The arguments to F.pad.
@@ -711,10 +716,14 @@ def _handle_circular_pad(cls, func, args, kwargs):
         """
         input_tensor = args[0]
         pad = args[1] if len(args) > 1 else kwargs.get("pad")
+        pad_list = list(pad)
         parallel_strategy = input_tensor._parallel_strategy
         shard_dim = parallel_strategy.shard_dim
-        pad_list = list(pad)
+        # If no padding in the sharded dimensions, then we do the padding using the standard op
+        if all([True if pad_list[::-2][s - 2] == 0 else False for s in shard_dim]):
+            return super().__torch_function__(func, (DCTensor,), args, kwargs)
+
 
         shard_padding = [
             0,
         ] * len(shard_dim)
@@ -747,11 +756,10 @@ def _handle_circular_pad(cls, func, args, kwargs):
             is_periodic[i] = False
 
         # Call F.pad with modified padding (shard dim padding disabled)
-        new_args = (_ToTensor.apply(input_tensor), tuple(pad_list)) + args[2:]
-        partial_padded_tensor = func(*new_args, **kwargs)
-
-        # Create result DCTensor with periodic flag and stored shard padding
-        result: DCTensor = _FromTensor.apply(partial_padded_tensor, parallel_strategy)
+        new_args = (input_tensor, tuple(pad_list)) + args[2:]
+        # Re-run padding function on DCTensor with periodic flag and stored shard padding.
+        # This will not actually apply the padding, but rather call __torch_function__.
+        result = func(*new_args, **kwargs)
         result._is_periodic = tuple(is_periodic)
         result._periodic_shard_padding = tuple(shard_padding)