From a9c8e935c05805c6fd493bd900f2e6be2e8a5550 Mon Sep 17 00:00:00 2001 From: vietlh Date: Sun, 7 Jun 2026 16:44:02 +0700 Subject: [PATCH 1/2] fix: count parameters shared across modules once (#322, #377) torchinfo built total_params by summing each row's num_params, with no deduplication across parameter tensors. When one tensor is referenced by multiple distinct modules (weight tying -- tied embeddings / lm_head, shared projection heads, etc.) it was counted once per referencing module, overestimating the total (e.g. flan-t5-small reported 93,410,688 vs the true 76,961,152). This differs from the module-instance sharing fixed in #327: tied tensors live on different module objects, so id(module)-based recursion detection doesn't catch them. Take parameter totals from the root module instead. A module's named_parameters() already deduplicates shared tensors (remove_duplicate defaults to True) and includes submodules not run in the forward pass, so this matches `sum(p.numel() for p in model.parameters())`. A module whose parameters were all already counted by an earlier row is marked "(recursive)" so the per-row counts still sum to the total. Add TiedWeightsModel + test_tied_weights, and regenerate the flan-t5 snapshot (Python 3.14). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/fixtures/models.py | 16 ++++++++++++ tests/test_output/flan_t5_small.out | 10 ++++---- tests/torchinfo_test.py | 16 ++++++++++++ torchinfo/layer_info.py | 28 ++++++++++++++++++++ torchinfo/model_statistics.py | 40 ++++++++++++++++++----------- 5 files changed, 90 insertions(+), 20 deletions(-) diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index e9c5f9c8..3d508235 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -700,6 +700,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class TiedWeightsModel(nn.Module): + """Reproduces #322/#377: a single parameter tensor shared across two distinct + modules (weight tying, as in tied embeddings / lm_head). The shared tensor + must be counted once, matching `sum(p.numel() for p in model.parameters())`. + """ + + def __init__(self, vocab_size: int = 1000, embed_dim: int = 64) -> None: + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.head = nn.Linear(embed_dim, vocab_size, bias=False) + self.head.weight = self.embedding.weight # tie the weights + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.head(self.embedding(x)) + + class PrunedLayerNameModel(nn.Module): """Model that defines parameters with _orig and _mask as suffixes.""" diff --git a/tests/test_output/flan_t5_small.out b/tests/test_output/flan_t5_small.out index e5541ddb..d63dd957 100644 --- a/tests/test_output/flan_t5_small.out +++ b/tests/test_output/flan_t5_small.out @@ -17,7 +17,7 @@ T5ForConditionalGeneration [3, 100, 512] │ └─T5LayerNorm: 2-4 [3, 100, 512] 512 │ └─Dropout: 2-5 [3, 100, 512] -- ├─T5Stack: 1-2 -- -- -│ └─Embedding: 2-6 [3, 100, 512] 16,449,536 +│ └─Embedding: 2-6 [3, 100, 512] (recursive) │ └─Dropout: 2-7 [3, 100, 512] -- │ └─ModuleList: 2-8 -- -- │ │ └─T5Block: 3-9 [3, 100, 512] 3,147,456 @@ -32,13 +32,13 @@ T5ForConditionalGeneration [3, 100, 512] │ └─Dropout: 2-10 [3, 100, 512] -- 
├─Linear: 1-3 [3, 100, 32128] 16,449,536 ============================================================================================================== -Total params: 93,410,688 -Trainable params: 93,410,688 +Total params: 76,961,152 +Trainable params: 76,961,152 Non-trainable params: 0 Total mult-adds (G): 18.25 ============================================================================================================== Input size (kB): 7.37 Forward/backward pass size (MB): 326.28 -Params size (MB): 373.64 -Estimated Total Size (MB): 699.93 +Params size (MB): 307.84 +Estimated Total Size (MB): 634.14 ============================================================================================================== diff --git a/tests/torchinfo_test.py b/tests/torchinfo_test.py index 5f63e219..f4d87d98 100644 --- a/tests/torchinfo_test.py +++ b/tests/torchinfo_test.py @@ -44,6 +44,7 @@ SimpleRNN, SingleInputNet, TensorKernelSizeConv, + TiedWeightsModel, UninitializedParameterModel, ) from torchinfo import ColumnSettings, summary @@ -657,6 +658,21 @@ def test_shared_module_in_nested_list() -> None: assert result.trainable_params == 120 +def test_tied_weights() -> None: + # Regression test for #322/#377: a single parameter tensor shared across two + # distinct modules (weight tying) was counted once per module, overestimating + # the total. It must be counted once, matching model.parameters(). 
+ model = TiedWeightsModel() + result = summary( + model, input_data=torch.randint(0, 1000, (2, 8)), depth=2, verbose=0 + ) + + expected = sum(p.numel() for p in model.parameters()) + assert expected == 64_000 # tied weight counted once, not 128_000 + assert result.total_params == expected + assert result.trainable_params == expected + + def test_hide_recursive_layers() -> None: model = SimpleRNN() summary(model, input_size=(2, 3)) diff --git a/torchinfo/layer_info.py b/torchinfo/layer_info.py index 44ecdb05..e3721f58 100644 --- a/torchinfo/layer_info.py +++ b/torchinfo/layer_info.py @@ -229,6 +229,34 @@ def calculate_num_params(self) -> None: f"└─{self.inner_layers[final_name][ColumnSettings.NUM_PARAMS][2:]}" ) + def get_direct_param_ids(self) -> list[tuple[int, int, int, bool]]: + """ + Return (param_tensor_id, num_params, param_bytes, is_trainable) for each + parameter owned *directly* by this module (recurse=False), applying the + same mask/lazy handling as calculate_num_params. + + Used by model_statistics.py to spot fully shared modules: a tensor referenced + by more than one module (weight tying, e.g. tied embeddings / lm_head) yields + the same tensor id on every reference, so a module whose ids were all seen + on earlier rows can be marked "(recursive)". The global totals are not + summed from these tuples; they are taken from the root module, whose + count already deduplicates shared tensors. 
+ """ + contributions = [] + for name, param in self.module.named_parameters(recurse=False): + if is_lazy(param): + continue # type: ignore[unreachable] + cur_params, _ = self.get_param_count(self.module, name, param) + contributions.append( + ( + id(param), + cur_params, + param.element_size() * cur_params, + param.requires_grad, + ) + ) + return contributions + def calculate_macs(self) -> None: """ Set MACs using the module's parameters and layer's output size, which is diff --git a/torchinfo/model_statistics.py b/torchinfo/model_statistics.py index de1d5a1f..b7be4806 100644 --- a/torchinfo/model_statistics.py +++ b/torchinfo/model_statistics.py @@ -27,28 +27,38 @@ def __init__( self.total_params, self.trainable_params = 0, 0 self.total_param_bytes, self.total_output_bytes = 0, 0 - # TODO: Figure out why the below functions using max() are ever 0 - # (they should always be non-negative), and remove the call to max(). - # Investigation: https://github.com/TylerYep/torchinfo/pull/195 for layer_info in summary_list: if layer_info.is_leaf_layer: self.total_mult_adds += layer_info.macs if layer_info.num_params > 0: # x2 for gradients self.total_output_bytes += layer_info.output_bytes * 2 - if layer_info.is_recursive: - continue - self.total_params += max(layer_info.num_params, 0) + + # Parameter totals are taken from the root module(s). A module's + # named_parameters() already deduplicates tensors shared across modules + # (weight tying, e.g. tied embeddings / lm_head) and counts parameters of + # submodules that weren't executed in this forward pass -- so this matches + # `sum(p.numel() for p in model.parameters())`. Summing the per-row counts + # instead would double-count tied weights (#322/#377). 
+ for layer_info in summary_list: + if layer_info.parent_info is None: + self.total_params += layer_info.num_params + self.trainable_params += layer_info.trainable_params self.total_param_bytes += layer_info.param_bytes - self.trainable_params += max(layer_info.trainable_params, 0) - else: - if layer_info.is_recursive: - continue - leftover_params = layer_info.leftover_params() - leftover_trainable_params = layer_info.leftover_trainable_params() - self.total_params += max(leftover_params, 0) - self.trainable_params += max(leftover_trainable_params, 0) - self.total_param_bytes += max(layer_info.leftover_param_bytes(), 0) + + # Mark a module as "(recursive)" when every parameter it owns directly was + # already counted by an earlier row (a fully shared/tied module), so the + # displayed per-row counts still sum to the deduplicated total. + seen_param_ids: set[int] = set() + for layer_info in summary_list: + direct_params = layer_info.get_direct_param_ids() + contributed_new = False + for param_id, _, _, _ in direct_params: + if param_id not in seen_param_ids: + seen_param_ids.add(param_id) + contributed_new = True + if not contributed_new and any(count for _, count, _, _ in direct_params): + layer_info.is_recursive = True self.formatting.set_layer_name_width(summary_list) def __repr__(self) -> str: From c8683b31bc0ce98eeef0e707981f48cd7a9dd42a Mon Sep 17 00:00:00 2001 From: vietlh Date: Sun, 7 Jun 2026 17:17:58 +0700 Subject: [PATCH 2/2] fix: cast TiedWeightsModel.forward return to satisfy mypy no-any-return Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/fixtures/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 3d508235..762aacb9 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -713,7 +713,7 @@ def __init__(self, vocab_size: int = 1000, embed_dim: int = 64) -> None: self.head.weight = self.embedding.weight # tie the weights def forward(self, x: 
torch.Tensor) -> torch.Tensor: - return self.head(self.embedding(x)) + return cast(torch.Tensor, self.head(self.embedding(x))) class PrunedLayerNameModel(nn.Module):