Add tests that reset_parameters doesn't change parameter initial value ranges #2550
Conversation
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Greptile Summary

Adds a test verifying that parameter values after `reset_parameters()` match their initial distributions.

Test methodology concern: the test compares statistics of two independent random samples (see comments below). Confidence Score: 3/5

Important Files Changed: tests/pytorch/test_deferred_init.py
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_reset_parameters_doesnt_change_parameter_stats
    participant Module as TE Module (Linear/LayerNormLinear/etc)
    participant Base as TransformerEngineBaseModule

    Test->>Module: Initialize module with device='cuda'
    Module->>Module: Create parameters with init functions
    Note over Module: layer_norm_weight: init_method_constant(1.0)<br/>bias: init_method_constant(0.0)<br/>weight: Xavier/Kaiming
    Test->>Module: Capture initial statistics
    Module-->>Test: {param: {mean, std}}
    Test->>Module: reset_parameters()
    Module->>Base: super().reset_parameters()
    Base->>Base: For each parameter
    Base->>Base: Get init_fn from param_init_meta
    Base->>Base: Call init_fn(param)
    Note over Base: Should use SAME init function<br/>registered during __init__
    Base-->>Module: Parameters reinitialized
    Module->>Module: Set parallelism attributes
    Module-->>Test: Done
    Test->>Module: Capture statistics after reset
    Module-->>Test: {param: {mean, std}}
    Test->>Test: Compare statistics with tolerances
    Note over Test: torch.testing.assert_close<br/>atol=1e-3, rtol=1e-3
    alt Statistics match
        Test->>Test: Test passes ✓
    else Statistics differ
        Test->>Test: Test fails (Bug detected) ✗
    end
```
Additional Comments (1)
- tests/pytorch/test_deferred_init.py, lines 110-123 — syntax: The `msg` parameter uses a lambda function, but `torch.testing.assert_close` expects a string, not a callable. This will cause the error message to display `<lambda>` instead of the actual error details.
1 file reviewed, 1 comment
ksivaman left a comment
LGTM
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
Additional Comments (2)
- tests/pytorch/test_deferred_init.py, lines 96-99 — logic: This test has a subtle issue: it compares statistics of two independent random samples from the same distribution, expecting them to be nearly identical (atol=1e-3, rtol=1e-3).
  For constant-initialized parameters (`layer_norm_weight=1.0`, `bias=0.0`), mean and std should be identical before and after, so this works. For randomly-initialized parameters (weight matrices), the test compares statistics of two different random samples; even with large parameter counts (~2M elements), there is sampling variability that could cause spurious failures.
Consider instead:
- For constant params: check exact values match expected constants
- For random params: either use much looser tolerances or check the initialization function directly (see the sketch after this list)
- tests/pytorch/test_deferred_init.py, lines 110-123 — style: The tolerances atol=1e-3, rtol=1e-3 are quite tight. For weight parameters initialized with random distributions (Xavier/Kaiming), two independent samples will have different means and standard deviations due to sampling variability.
  For example, with Xavier uniform on a 1024×2048 matrix:
  - Standard error of the mean ≈ σ/√n ≈ 0.00001
  - This is well within tolerance, but it's fragile
  For smaller parameter tensors (e.g., bias vectors with 2048 elements), the sampling variability increases. If these tests start failing intermittently, consider loosening tolerances or restructuring the test to check initialization functions directly rather than statistics of samples. A worked version of the arithmetic follows.
1 file reviewed, 2 comments
Adds tests to ensure that the parameter values after `reset_parameters()` match their initial distributions.

Fixes #2528, #2529