Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class TiedWeightsModel(nn.Module):
    """Embedding followed by a bias-free linear head that reuse one weight tensor.

    Fixture for #322/#377 (weight tying, as with tied embeddings / lm_head):
    ``head.weight`` and ``embedding.weight`` are the very same Parameter object,
    so a correct parameter count includes that tensor only once, i.e. it equals
    `sum(p.numel() for p in model.parameters())`.
    """

    def __init__(self, vocab_size: int = 1000, embed_dim: int = 64) -> None:
        super().__init__()
        # Build order (embedding first, then head) is kept so that module
        # registration order and random initialisation stay the same.
        token_embedding = nn.Embedding(vocab_size, embed_dim)
        output_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.embedding = token_embedding
        self.head = output_head
        # Tie the weights: the head now points at the embedding's Parameter,
        # replacing the weight nn.Linear created for itself.
        self.head.weight = self.embedding.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and project the result through the tied head."""
        embedded = self.embedding(x)
        projected = self.head(embedded)
        return cast(torch.Tensor, projected)


class PrunedLayerNameModel(nn.Module):
"""Model that defines parameters with _orig and _mask as suffixes."""

Expand Down
10 changes: 5 additions & 5 deletions tests/test_output/flan_t5_small.out
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ T5ForConditionalGeneration [3, 100, 512]
│ └─T5LayerNorm: 2-4 [3, 100, 512] 512
│ └─Dropout: 2-5 [3, 100, 512] --
├─T5Stack: 1-2 -- --
│ └─Embedding: 2-6 [3, 100, 512] 16,449,536
│ └─Embedding: 2-6 [3, 100, 512] (recursive)
│ └─Dropout: 2-7 [3, 100, 512] --
│ └─ModuleList: 2-8 -- --
│ │ └─T5Block: 3-9 [3, 100, 512] 3,147,456
Expand All @@ -32,13 +32,13 @@ T5ForConditionalGeneration [3, 100, 512]
│ └─Dropout: 2-10 [3, 100, 512] --
├─Linear: 1-3 [3, 100, 32128] 16,449,536
==============================================================================================================
Total params: 93,410,688
Trainable params: 93,410,688
Total params: 76,961,152
Trainable params: 76,961,152
Non-trainable params: 0
Total mult-adds (G): 18.25
==============================================================================================================
Input size (kB): 7.37
Forward/backward pass size (MB): 326.28
Params size (MB): 373.64
Estimated Total Size (MB): 699.93
Params size (MB): 307.84
Estimated Total Size (MB): 634.14
==============================================================================================================
16 changes: 16 additions & 0 deletions tests/torchinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
SimpleRNN,
SingleInputNet,
TensorKernelSizeConv,
TiedWeightsModel,
UninitializedParameterModel,
)
from torchinfo import ColumnSettings, summary
Expand Down Expand Up @@ -659,6 +660,21 @@ def test_shared_module_in_nested_list() -> None:
assert result.trainable_params == 120


def test_tied_weights() -> None:
    # Regression guard for #322/#377: when one parameter tensor is shared by
    # two distinct modules (weight tying), the summary used to add it once per
    # module and report an inflated total. The total must equal what
    # model.parameters() yields, where the shared tensor appears only once.
    tied_model = TiedWeightsModel()
    token_ids = torch.randint(0, 1000, (2, 8))
    stats = summary(tied_model, input_data=token_ids, depth=2, verbose=0)

    unique_param_count = sum(
        weight.numel() for weight in tied_model.parameters()
    )
    # 64_000 means the tied tensor was counted once; 128_000 would mean twice.
    assert unique_param_count == 64_000
    assert stats.total_params == unique_param_count
    assert stats.trainable_params == unique_param_count


def test_hide_recursive_layers() -> None:
model = SimpleRNN()
summary(model, input_size=(2, 3))
Expand Down
28 changes: 28 additions & 0 deletions torchinfo/layer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,34 @@ def calculate_num_params(self) -> None:
f"└─{self.inner_layers[final_name][ColumnSettings.NUM_PARAMS][2:]}"
)

def get_direct_param_ids(self) -> list[tuple[int, int, int, bool]]:
    """
    Describe every parameter registered directly on this layer's module.

    Only the module's own parameters are visited (recurse=False); parameters
    of child modules are not included. Lazy (uninitialized) parameters are
    skipped. The element count of each remaining parameter is whatever
    self.get_param_count reports for it.

    Returns one tuple per parameter, in named_parameters() order:
        (id of the parameter tensor, element count,
         element count * element_size() in bytes, requires_grad flag)

    The id is the same for every reference to one tensor object, so a caller
    can use it to recognise a tensor that is shared by several modules
    (weight tying, e.g. tied embeddings / lm_head) and avoid counting it twice.
    """
    owned: list[tuple[int, int, int, bool]] = []
    for param_name, tensor in self.module.named_parameters(recurse=False):
        if is_lazy(tensor):
            continue  # type: ignore[unreachable]
        # get_param_count returns a pair; only its first item (the count)
        # is needed here.
        num_params = self.get_param_count(self.module, param_name, tensor)[0]
        num_bytes = tensor.element_size() * num_params
        owned.append((id(tensor), num_params, num_bytes, tensor.requires_grad))
    return owned

def calculate_macs(self) -> None:
"""
Set MACs using the module's parameters and layer's output size, which is
Expand Down
40 changes: 25 additions & 15 deletions torchinfo/model_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,38 @@ def __init__(
self.total_params, self.trainable_params = 0, 0
self.total_param_bytes, self.total_output_bytes = 0, 0

# TODO: Figure out why the below functions using max() are ever 0
# (they should always be non-negative), and remove the call to max().
# Investigation: https://github.com/TylerYep/torchinfo/pull/195
for layer_info in summary_list:
if layer_info.is_leaf_layer:
self.total_mult_adds += layer_info.macs
if layer_info.num_params > 0:
# x2 for gradients
self.total_output_bytes += layer_info.output_bytes * 2
if layer_info.is_recursive:
continue
self.total_params += max(layer_info.num_params, 0)

# Parameter totals are taken from the root module(s). A module's
# named_parameters() already deduplicates tensors shared across modules
# (weight tying, e.g. tied embeddings / lm_head) and counts parameters of
# submodules that weren't executed in this forward pass -- so this matches
# `sum(p.numel() for p in model.parameters())`. Summing the per-row counts
# instead would double-count tied weights (#322/#377).
for layer_info in summary_list:
if layer_info.parent_info is None:
self.total_params += layer_info.num_params
self.trainable_params += layer_info.trainable_params
self.total_param_bytes += layer_info.param_bytes
self.trainable_params += max(layer_info.trainable_params, 0)
else:
if layer_info.is_recursive:
continue
leftover_params = layer_info.leftover_params()
leftover_trainable_params = layer_info.leftover_trainable_params()
self.total_params += max(leftover_params, 0)
self.trainable_params += max(leftover_trainable_params, 0)
self.total_param_bytes += max(layer_info.leftover_param_bytes(), 0)

# Mark a module as "(recursive)" when every parameter it owns directly was
# already counted by an earlier row (a fully shared/tied module), so the
# displayed per-row counts still sum to the deduplicated total.
seen_param_ids: set[int] = set()
for layer_info in summary_list:
direct_params = layer_info.get_direct_param_ids()
contributed_new = False
for param_id, _, _, _ in direct_params:
if param_id not in seen_param_ids:
seen_param_ids.add(param_id)
contributed_new = True
if not contributed_new and any(count for _, count, _, _ in direct_params):
layer_info.is_recursive = True
self.formatting.set_layer_name_width(summary_list)

def __repr__(self) -> str:
Expand Down
Loading