
Conversation

pstjohn (Contributor) commented Dec 30, 2025

Adds tests to ensure that the parameter values after reset_parameters() match their initial distributions.

Fixes #2528, #2529

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
greptile-apps bot commented Dec 30, 2025

Greptile Summary

Adds test test_reset_parameters_doesnt_change_parameter_stats to verify that calling reset_parameters() on TE modules maintains the same initialization statistics (mean and std) for all parameters.

Key changes:

  • Parameterized test across all core modules (LayerNorm, RMSNorm, Linear, LayerNormLinear, LayerNormMLP)
  • Compares parameter statistics before and after reset_parameters() call with tight tolerances (atol=1e-3, rtol=1e-3)
  • Removes extraneous blank line in class definition

Test methodology concern:
The test compares statistics of two independent random samples from the same distribution, which works well for constant-initialized parameters (layer norm weights, biases) but may be fragile for randomly-initialized weight matrices. The test will correctly catch the bugs described in issues #2528 (layer_norm_weight should be 1.0) and #2529 (bias should be 0.0), since those have std=0 initially and would show std>0 if incorrectly reinitialized. However, for weight matrices with random initialization, comparing statistics of two different random samples could theoretically cause false positives, though large parameter counts make this unlikely in practice.
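For illustration, a minimal sketch of the statistics-comparison pattern described above (the module choice, helper name, and tolerances are illustrative, not the PR's exact code):

    import torch
    import transformer_engine.pytorch as te

    def param_stats(module):
        # Record per-parameter mean/std as plain Python floats.
        return {
            name: (p.detach().float().mean().item(), p.detach().float().std().item())
            for name, p in module.named_parameters()
        }

    # Hypothetical module under test; any of the parameterized modules would do.
    module = te.Linear(1024, 2048, device="cuda")
    before = param_stats(module)

    module.reset_parameters()
    after = param_stats(module)

    # Exact agreement is expected for constant-initialized params (bias=0.0,
    # layer_norm_weight=1.0); for random weights this compares two independent draws.
    for name in before:
        torch.testing.assert_close(after[name], before[name], atol=1e-3, rtol=1e-3)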

Confidence Score: 3/5

  • This PR is moderately safe to merge: it adds test coverage for known initialization bugs without changing production code
  • The PR adds test coverage for issues #2528 (LayerNormLinear reset_parameters() leads to the wrong initialization) and #2529 (Linear layer reset_parameters() changes bias zero init to random init near 0), which is valuable for preventing regressions once those bugs are fixed. However, the test methodology has a conceptual issue: it compares statistics of independent random samples, which could theoretically lead to false positives. The test will correctly catch the constant-initialization bugs (layer_norm_weight and bias), but the approach is fragile for randomly-initialized parameters. Given that parameter counts are large (millions of elements), false positives are unlikely in practice, but the test design could be more robust.
  • tests/pytorch/test_deferred_init.py requires attention; consider revising the test methodology to handle random and constant initialization separately

Important Files Changed

  • tests/pytorch/test_deferred_init.py: Added test to verify reset_parameters() preserves parameter statistics, but the test logic may have false-positive issues for randomly-initialized parameters

Sequence Diagram

sequenceDiagram
    participant Test as test_reset_parameters_doesnt_change_parameter_stats
    participant Module as TE Module (Linear/LayerNormLinear/etc)
    participant Base as TransformerEngineBaseModule
    
    Test->>Module: Initialize module with device='cuda'
    Module->>Module: Create parameters with init functions
    Note over Module: layer_norm_weight: init_method_constant(1.0)<br/>bias: init_method_constant(0.0)<br/>weight: Xavier/Kaiming
    
    Test->>Module: Capture initial statistics
    Module-->>Test: {param: {mean, std}}
    
    Test->>Module: reset_parameters()
    Module->>Base: super().reset_parameters()
    Base->>Base: For each parameter
    Base->>Base: Get init_fn from param_init_meta
    Base->>Base: Call init_fn(param)
    Note over Base: Should use SAME init function<br/>registered during __init__
    Base-->>Module: Parameters reinitialized
    Module->>Module: Set parallelism attributes
    Module-->>Test: Done
    
    Test->>Module: Capture statistics after reset
    Module-->>Test: {param: {mean, std}}
    
    Test->>Test: Compare statistics with tolerances
    Note over Test: torch.testing.assert_close<br/>atol=1e-3, rtol=1e-3
    
    alt Statistics match
        Test->>Test: Test passes ✓
    else Statistics differ
        Test->>Test: Test fails (Bug detected) ✗
    end
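
In prose, the flow above is the base module re-running whichever init function was registered for each parameter at construction time. A schematic sketch of that loop (param_init_meta and init_fn follow the diagram's labels and are not guaranteed to match TE's exact internals):

    def reset_parameters(module):
        # Re-apply the init function registered for each parameter during __init__.
        for name, param in module.named_parameters(recurse=False):
            init_fn = module.param_init_meta[name].init_fn
            if init_fn is not None:
                # Should be the SAME function as at construction, e.g.
                # init_method_constant(1.0) for layer_norm_weight,
                # init_method_constant(0.0) for bias, Xavier/Kaiming for weight.
                init_fn(param)
        # The module then restores parallelism attributes on the new parameters.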

greptile-apps bot left a comment

Additional Comments (1)

  1. tests/pytorch/test_deferred_init.py, lines 110-123

    syntax: The msg parameter uses a lambda function, but torch.testing.assert_close expects a string, not a callable. This will cause the error message to display <lambda> instead of the actual error details.
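
    For reference, a minimal sketch of passing a plain string instead (the names name, before, and after are assumed from the surrounding test, not copied from it):

        torch.testing.assert_close(
            after[name],
            before[name],
            atol=1e-3,
            rtol=1e-3,
            # A plain string is rendered as-is in the failure output.
            msg=f"Statistics for parameter '{name}' changed after reset_parameters()",
        )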

1 file reviewed, 1 comment


ksivaman (Member) previously approved these changes Jan 2, 2026

ksivaman left a comment
LGTM

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ksivaman commented Jan 2, 2026

/te-ci pytorch L0

greptile-apps bot left a comment

Additional Comments (2)

  1. tests/pytorch/test_deferred_init.py, lines 96-99

    logic: This test logic has a subtle issue: it compares statistics of two independent random samples from the same distribution, expecting them to be nearly identical (atol=1e-3, rtol=1e-3).

    For constant-initialized parameters (layer_norm_weight=1.0, bias=0.0): mean and std should be identical before/after, so this works.

    For randomly-initialized parameters (weight matrices): the test compares statistics of two different random samples. Even with large parameter counts (~2M elements), there's sampling variability that could cause spurious failures.

    Consider instead (a sketch of this split appears after these comments):

    • For constant params: check that the exact values match the expected constants
    • For random params: either use much looser tolerances or check the initialization function directly
  2. tests/pytorch/test_deferred_init.py, lines 110-123

    style: The tolerances atol=1e-3, rtol=1e-3 are quite tight. For weight parameters initialized with random distributions (Xavier/Kaiming), two independent samples will have different means and standard deviations due to sampling variability.

    For example, with Xavier uniform on a 1024×2048 matrix (σ ≈ 0.026, n ≈ 2.1M elements):

    • Standard error of the mean ≈ σ/√n ≈ 2e-5
    • This is well within tolerance, but it's fragile

    For smaller parameter tensors (e.g., bias vectors with 2048 elements), the sampling variability increases. If these tests start failing intermittently, consider loosening tolerances or restructuring the test to check initialization functions directly rather than statistics of samples.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
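
As referenced in comment 1 above, a sketch of the suggested constant/random split (the parameter-name mapping and tolerances are assumptions for illustration, not TE metadata or the PR's code):

    import torch

    # Hypothetical classification by parameter name: constant-initialized
    # parameters get exact checks, randomly-initialized ones get loose
    # statistical checks instead of tight atol/rtol on two independent draws.
    CONSTANT_PARAMS = {"layer_norm_weight": 1.0, "layer_norm_bias": 0.0, "bias": 0.0}

    def check_after_reset(module, before_stats):
        for name, param in module.named_parameters():
            data = param.detach().float()
            if name in CONSTANT_PARAMS:
                expected = torch.full_like(data, CONSTANT_PARAMS[name])
                torch.testing.assert_close(data, expected, atol=0.0, rtol=0.0)
            else:
                mean_before, std_before = before_stats[name]
                # Zero-centered random inits have means within a few 1e-5 of zero
                # for ~2M-element tensors, so an absolute 1e-3 bound is comfortable...
                torch.testing.assert_close(data.mean().item(), mean_before, atol=1e-3, rtol=0.0)
                # ...while the stds of two independent draws are compared with a
                # relative tolerance rather than a tight absolute one.
                torch.testing.assert_close(data.std().item(), std_before, rtol=0.05, atol=1e-4)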

1 file reviewed, 2 comments

