
Fix one-sided padding with sharded tensors. #1535

Open
nloppi wants to merge 2 commits into NVIDIA:main from nloppi:main

Conversation


@nloppi nloppi commented Mar 25, 2026

PhysicsNeMo Pull Request

Description

A tiny PR that fixes a bug that occurs when we do one-sided padding with sharded tensors, or pad sharded tensors with uneven chunks. I found it while working on a differentiable solver with sharded tensors, where the partitions at the domain edges also need to contain extra ghost/halo elements for boundary conditions.

To reproduce

import torch

from torch.distributed.tensor.placement_types import Shard
from torch.distributed.tensor import distribute_module

from physicsnemo.distributed import DistributedManager
from physicsnemo.domain_parallel import scatter_tensor, ShardTensor

class test_pad(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.nn.functional.pad(x, (0, 0, 2, 0), mode='replicate')
        x = torch.nn.functional.pad(x, (0, 0, 0, 2), mode='constant', value=1.0)
        return x

DistributedManager.initialize()
dm = DistributedManager()
nparts = 4
dmesh = dm.initialize_mesh(
    mesh_shape=(nparts,),
    mesh_dim_names=["domain"],
)

x = torch.ones(1, 1, 80, 10).to(dm.device)
shard_x = scatter_tensor(x, 0, dmesh, (Shard(dim=2),))

if dm.rank in [0, 3]:
    y = torch.ones(1,1,22,10).to(dm.device)
else:
    y = torch.ones(1,1,20,10).to(dm.device)

shard_y = ShardTensor.from_local(
    y, dmesh, (Shard(dim=2),), sharding_shapes="infer"
)

dist_test_pad = distribute_module(test_pad().to(dm.device), device_mesh=dmesh)

x_pad = dist_test_pad(shard_x)
out = shard_y + x_pad

print("full_tensor",  out.full_tensor())

and run with torchrun --nproc-per-node 4 test.py
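For context, the uneven shard sizes in the reproducer can be derived by hand: an 80-row tensor split over 4 ranks gives 20 rows per rank, and each one-sided pad of 2 rows lands only on the edge rank for that side. A minimal sketch of that bookkeeping (illustration only, assuming an even base split; these helpers are hypothetical, not PhysicsNeMo APIs):

```python
def shard_heights(total_rows, nparts):
    """Even split of total_rows rows across nparts ranks."""
    base = total_rows // nparts
    return [base] * nparts

def pad_heights(heights, pad_before, pad_after):
    """One-sided padding only grows the edge shards."""
    padded = list(heights)
    padded[0] += pad_before    # leading-edge pad lands on rank 0
    padded[-1] += pad_after    # trailing-edge pad lands on the last rank
    return padded

heights = shard_heights(80, 4)         # [20, 20, 20, 20]
heights = pad_heights(heights, 2, 0)   # replicate pad: [22, 20, 20, 20]
heights = pad_heights(heights, 0, 2)   # constant pad:  [22, 20, 20, 22]
print(heights)  # [22, 20, 20, 22]
```

This matches the per-rank shapes of `y` in the reproducer: ranks 0 and 3 hold 22 rows, the interior ranks 20, so the addition only succeeds if the padded tensor's sharding metadata records those uneven shapes correctly.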

Checklist

Dependencies

N/A

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@nloppi nloppi requested a review from coreyjadams as a code owner March 25, 2026 19:49
@greptile-apps
Contributor

greptile-apps bot commented Mar 25, 2026

Greptile Summary

This PR fixes a bug in generic_pad_nd_wrapper where the output shape metadata was computed using local_input.shape (the current rank's local tensor shape) instead of local_shape (the per-shard shape from spec.sharding_shapes()) when iterating over all mesh positions to build the output_shapes dict.

Root cause: The outer loop in generic_pad_nd_wrapper iterates over all shard positions to build the complete output_shapes dict (used as sharding_shapes metadata for the output ShardTensor). For each position i, local_shape holds the actual shape of the shard at that position. The old code passed local_input.shape — always the current rank's local tensor — causing every non-current-rank entry in output_shapes to reflect the wrong shape. This made different ranks compute inconsistent, contradictory metadata for the output ShardTensor.

When does it trigger? The mismatch only becomes observable when the sharding is uneven across ranks, which happens naturally after a one-sided padding operation (e.g., only the leftmost rank gains elements). A subsequent pad operation on such an unevenly-sharded tensor would record wrong sharding_shapes, breaking any downstream operation that relies on those shapes (e.g., tensor addition with a correctly-shaped ShardTensor).

Fix assessment: The one-line change is correct and minimal. local_shape is exactly the right value to pass because it comes from spec.sharding_shapes(), which is consistent and authoritative across all ranks.
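To make the root cause concrete, here is a hedged sketch of the loop pattern described above. The names and shapes are illustrative only; the actual `generic_pad_nd_wrapper` code differs.

```python
# Illustration of the bug class, not the actual padding.py code.
# `sharding_shapes` stands in for spec.sharding_shapes();
# `local_input_shape` is the current rank's local tensor shape.

def compute_output_shapes(sharding_shapes, local_input_shape, pad_rows):
    buggy, fixed = {}, {}
    for i, local_shape in enumerate(sharding_shapes):
        # Buggy: every entry derives from the *current rank's* shape.
        buggy[i] = (local_input_shape[0] + pad_rows(i), local_input_shape[1])
        # Fixed: each entry derives from the shard at position i.
        fixed[i] = (local_shape[0] + pad_rows(i), local_shape[1])
    return buggy, fixed

# Uneven sharding: rank 0 already carries 2 extra rows.
shapes = [(22, 10), (20, 10), (20, 10), (20, 10)]
pad = lambda i: 2 if i == 3 else 0   # pad only the last shard

# Computed from rank 1's perspective (local shape (20, 10)):
buggy, fixed = compute_output_shapes(shapes, shapes[1], pad)
print(buggy[0])  # (20, 10) -- wrong: rank 0's extra rows are lost
print(fixed[0])  # (22, 10) -- correct per-shard metadata
```

With even sharding the two versions coincide, which is why the bug stayed hidden until one-sided padding produced uneven shards.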

Notable checklist items:

  • The "New or existing tests cover these changes" box is unchecked. A regression test reproducing the scenario from the PR description would guard against this bug class reappearing.
  • The CHANGELOG.md entry is also absent.
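One cheap invariant such a regression test could check, sketched here without any distributed setup (the helper name is hypothetical): the per-shard row counts recorded in the output sharding metadata must tile the global padded dimension exactly, on every rank.

```python
def check_sharding_metadata(output_shapes, global_padded_rows):
    """Assert the per-shard shapes tile the global padded shape.

    output_shapes: dict mapping shard position -> (rows, cols).
    """
    total = sum(rows for rows, _ in output_shapes.values())
    assert total == global_padded_rows, (total, global_padded_rows)

# With the fix, [22, 20, 20, 22] rows -> 84 == 80 + 2 + 2 padded rows.
check_sharding_metadata(
    {0: (22, 10), 1: (20, 10), 2: (20, 10), 3: (22, 10)}, 84
)
print("metadata consistent")
```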

Important Files Changed

Filename Overview
physicsnemo/domain_parallel/shard_utils/padding.py One-line fix replacing local_input.shape with local_shape in the sharding shapes loop, ensuring each rank's shape metadata is derived from the iterated shard rather than the current rank's local tensor; fix is correct but lacks new test coverage.


@peterdsharpe
Collaborator

/blossom-ci

@coreyjadams
Collaborator

Hi @nloppi thanks for opening this! From what I can see it looks absolutely reasonable. I will take a look. It's awesome to see it being used!

Is this urgent? We are heading into a long weekend, and I still need to manually test this before merging.

Meanwhile, do you think you could add your test case / reproducer to the test suite? I think it could go here:
https://github.com/NVIDIA/physicsnemo/blob/main/test/domain_parallel/ops/test_padding.py

With the way you wrote it out with functionals, it's probably easiest to mimic the way we're testing attention:
https://github.com/NVIDIA/physicsnemo/blob/main/test/domain_parallel/ops/test_sdpa.py

If you want some guidance on running the multi gpu tests, you can find it here:
https://github.com/NVIDIA/physicsnemo/blob/main/test/plugins/distributed_fixtures.py

Sorry, we have this CI running internally post-merge so we catch issues that crop up but it's not set up for PRs yet. It's on the to-do list ...

@nloppi
Author

nloppi commented Mar 27, 2026

Hi @coreyjadams! Thanks!! I'll gladly add the test, and thanks for the tips. There's no rush. We can catch up next week :)

@nloppi
Author

nloppi commented Apr 9, 2026

Hi @coreyjadams! I added the test. I think I need to do it this way: if I only use the final assert comparing the global shapes, sharded_image_pad.full_tensor() just hangs, whereas now it can raise a shape mismatch on the local shards. Please do suggest clean-ups as you best see fit :)

@nloppi
Author

nloppi commented Apr 14, 2026

I still need to rerun the test on a bigger system (after the new clean-up commit).

Signed-off-by: Niki Loppi <nloppi@nvidia.com>
@nloppi
Author

nloppi commented Apr 15, 2026

Hi @coreyjadams! I think this should be more or less ready now. I added the pre-commit clean-ups, validated that the new test works on an 8 GPU system, rebased and squashed everything under a single commit. Many thanks for your help!

@coreyjadams
Collaborator

Thanks @nloppi ! FYI I think we need to make your PR dependent on this one to merge first:
#1556

It shouldn't be an issue, I just want to make sure we test your fixes with the final bug-fixes for ShardTensor going into the release.

