From 67159e9e965b65dbad32b09ec64a48e84714a6b4 Mon Sep 17 00:00:00 2001 From: vietlh Date: Sun, 7 Jun 2026 11:28:18 +0700 Subject: [PATCH 1/4] fix: correct param counts for modules shared across parents (#327) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A module instance shared by several parents (e.g. one nn.ReLU() passed into every block, as in the reported VNet) was counted incorrectly, inflating the total parameters — especially when combined with nested ModuleLists. Two root causes, both stemming from a shared module having one parent recorded instead of many: 1. Hierarchy (torchinfo.py): the pre-hook captured (var_name, depth, parent_info) at registration time and kept only the last parent, so every execution of a shared module reported the wrong parent. This scrambled the layer tree and mis-grouped children. Fixed by resolving the parent dynamically at execution time: accumulate every structural context a module is reached through, maintain a runtime call stack via the pre/post hooks, and select the context whose nearest executing ancestor is the current stack top. Single-parent modules are unchanged. 2. Counting (layer_info.py): leftover_params() excluded recursive children from its subtraction, re-attributing a recursive child's params (already counted at their real occurrence) to the parent — counting a shared parameterized module once per parent. Fixed with a shared _leftover() helper that subtracts each distinct child once (keyed by layer_id) and skips recursive subtrees. Adds the SharedModuleInNestedList fixture and a regression test. Verified no behavioral change for existing models (RecursiveNet, ReuseReLU, ReuseLinear, SimpleRNN, etc. all produce identical output). Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/fixtures/models.py | 33 +++++++ .../shared_module_in_nested_list.out | 37 ++++++++ tests/torchinfo_test.py | 16 ++++ torchinfo/layer_info.py | 59 +++++++++---- torchinfo/torchinfo.py | 86 ++++++++++++++++--- 5 files changed, 199 insertions(+), 32 deletions(-) create mode 100644 tests/test_output/shared_module_in_nested_list.out diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 4f6dc98..e9c5f9c 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -667,6 +667,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return cast(torch.Tensor, self.model(x)) +class SharedModuleInNestedList(nn.Module): + """Reproduces #327: a single module instance shared across several parents, + combined with nested ModuleLists. + + Each block holds its layers in a ModuleList and reuses one shared activation + that is also referenced by its sibling blocks. Recording a single parent for + the shared module corrupted the layer hierarchy and double-counted params. + """ + + class Block(nn.Module): + def __init__(self, act: nn.Module) -> None: + super().__init__() + self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)]) + self.act = act # shared instance, passed in from the parent + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = self.act(layer(x)) + return x + + def __init__(self) -> None: + super().__init__() + shared_act = nn.ReLU() + self.blocks = nn.ModuleList( + [SharedModuleInNestedList.Block(shared_act) for _ in range(3)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + class PrunedLayerNameModel(nn.Module): """Model that defines parameters with _orig and _mask as suffixes.""" diff --git a/tests/test_output/shared_module_in_nested_list.out b/tests/test_output/shared_module_in_nested_list.out new file mode 100644 index 0000000..7038e7d --- /dev/null +++ b/tests/test_output/shared_module_in_nested_list.out @@ -0,0 +1,37 @@ +========================================================================================== +Layer (type:depth-idx) Output Shape Param # +========================================================================================== +SharedModuleInNestedList [1, 4] -- +├─ModuleList: 1-1 -- -- +│ └─Block: 2-1 [1, 4] -- +│ │ └─ModuleList: 3-3 -- (recursive) +│ │ │ └─Linear: 4-1 [1, 4] 20 +│ │ └─ReLU: 3-2 [1, 4] -- +│ │ └─ModuleList: 3-3 -- (recursive) +│ │ │ └─Linear: 4-2 [1, 4] 20 +│ │ └─ReLU: 3-4 [1, 4] -- +│ └─Block: 2-2 [1, 4] -- +│ │ └─ModuleList: 3-7 -- (recursive) +│ │ │ └─Linear: 4-3 [1, 4] 20 +│ │ └─ReLU: 3-6 [1, 4] -- +│ │ └─ModuleList: 3-7 -- (recursive) +│ │ │ └─Linear: 4-4 [1, 4] 20 +│ │ └─ReLU: 3-8 [1, 4] -- +│ └─Block: 2-3 [1, 4] -- +│ │ └─ModuleList: 3-11 -- (recursive) +│ │ │ └─Linear: 4-5 [1, 4] 20 +│ │ └─ReLU: 3-10 [1, 4] -- +│ │ └─ModuleList: 3-11 -- (recursive) +│ │ │ └─Linear: 4-6 [1, 4] 20 +│ │ └─ReLU: 3-12 [1, 4] -- +========================================================================================== +Total params: 120 +Trainable params: 120 +Non-trainable params: 0 +Total mult-adds: 120 +========================================================================================== +Input size: 72 +Forward/backward pass size: 192 +Params size: 480 +Estimated Total Size: 744 +========================================================================================== diff --git a/tests/torchinfo_test.py b/tests/torchinfo_test.py index 64675f7..5f63e21 100644 --- a/tests/torchinfo_test.py +++ b/tests/torchinfo_test.py @@ -39,6 +39,7 @@ ReuseLinearExtended, ReuseReLU, ScalarTensorInputNet, + SharedModuleInNestedList, SiameseNets, SimpleRNN, SingleInputNet, @@ -641,6 +642,21 @@ def test_cnn_module_list() -> None: summary(CNNModuleList(ConvLayerB), input_size=[1, 1, 10]) +def test_shared_module_in_nested_list() -> None: + # Regression test for #327: a module instance shared across several parents + # (here one ReLU reused by every block) combined with nested ModuleLists + # previously corrupted the hierarchy and double-counted parameters. + model = SharedModuleInNestedList() + # depth=4 so the shared Linear leaves are shown and the Param# column is + # visibly accounted for (each counted once, not once per parent). + result = summary(model, input_size=(1, 4), depth=4) + + # 3 blocks * 2 Linear(4, 4) = 6 layers, each 4*4 + 4 = 20 params -> 120 total + expected = sum(p.numel() for p in model.parameters() if p.requires_grad) + assert result.total_params == expected == 120 + assert result.trainable_params == 120 + + def test_hide_recursive_layers() -> None: model = SimpleRNN() summary(model, input_size=(2, 3)) diff --git a/torchinfo/layer_info.py b/torchinfo/layer_info.py index bd2a16f..44ecdb0 100644 --- a/torchinfo/layer_info.py +++ b/torchinfo/layer_info.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, cast import numpy as np import torch @@ -305,36 +305,59 @@ def params_percent( return zero return f"{params / total_params:>{precision + spacing}.{precision}%}" + def _leftover(self, attr: str) -> int: + """ + Shared implementation for the leftover_* properties: the amount of `attr` + (num_params / trainable_params / param_bytes) owned directly by this layer + and not already accounted for by a descendant row in the summary. + + Each distinct child module is subtracted at most once (keyed by layer_id), + which matters when a module is reused: it appears in this layer's + named_parameters only once, but shows up as several rows. A recursive + child's whole subtree is counted at its non-recursive occurrence (here or + under another parent), so its `attr` is subtracted once and its subtree is + skipped. + """ + total = cast("int", getattr(self, attr)) + children = self.children + seen: set[int] = set() + i, n = 0, len(children) + while i < n: + child = children[i] + if child.is_recursive: + if child.layer_id not in seen: + seen.add(child.layer_id) + total -= cast("int", getattr(child, attr)) + # Skip the rest of the recursive subtree; it is part of `child`. + i += 1 + while i < n and children[i].depth > child.depth: + i += 1 + continue + seen.add(child.layer_id) + total -= ( + cast("int", getattr(child, attr)) + if child.is_leaf_layer + else child._leftover(attr) + ) + i += 1 + return total + def leftover_params(self) -> int: """ Leftover params are the number of params this current layer has that are not included in the child num_param counts. """ - return self.num_params - sum( - child.num_params if child.is_leaf_layer else child.leftover_params() - for child in self.children - if not child.is_recursive - ) + return self._leftover("num_params") def leftover_trainable_params(self) -> int: - return self.trainable_params - sum( - child.trainable_params - if child.is_leaf_layer - else child.leftover_trainable_params() - for child in self.children - if not child.is_recursive - ) + return self._leftover("trainable_params") def leftover_param_bytes(self) -> int: """ Leftover param bytes are the bytes of this layer's params that are not included in the child param_bytes counts. """ - return self.param_bytes - sum( - child.param_bytes if child.is_leaf_layer else child.leftover_param_bytes() - for child in self.children - if not child.is_recursive - ) + return self._leftover("param_bytes") def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], int]: diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index 3eb19c8..e79b623 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -571,29 +571,66 @@ def get_correct_input_sizes(input_size: INPUT_SIZE_TYPE) -> CORRECTED_INPUT_SIZE return [input_size] +LayerContext = tuple[str, int, "LayerInfo | None"] + + +def resolve_layer_context( + contexts: list[LayerContext], module_stack: list[LayerInfo] +) -> LayerContext: + """ + Pick the structural context (var_name, depth, parent) matching the module's + actual parent during the current forward call. + + A module instance shared by several parents is visited once per parent during + traversal, producing multiple contexts. Picking statically (as the last one + registered) corrupts the hierarchy for all the other call sites. The correct + context is the one whose nearest executing ancestor is the module currently on + top of the runtime stack. + """ + if len(contexts) == 1 or not module_stack: + return contexts[0] + top = module_stack[-1] + on_stack = {id(frame.module) for frame in module_stack} + for context in contexts: + parent = context[2] + # Walk up past containers (ModuleList/ModuleDict have no forward and never + # appear on the runtime stack) to the nearest currently-executing ancestor. + while parent is not None and id(parent.module) not in on_stack: + parent = parent.parent_info + if parent is not None and parent.module is top.module: + return context + return contexts[0] + + def construct_pre_hook( global_layer_info: dict[int, LayerInfo], summary_list: list[LayerInfo], layer_ids: set[int], - var_name: str, - curr_depth: int, - parent_info: LayerInfo | None, + module_contexts: dict[int, list[LayerContext]], + module_stack: list[LayerInfo], + module_id: int, ) -> Callable[[nn.Module, Any], None]: def pre_hook(module: nn.Module, inputs: Any) -> None: """Create a LayerInfo object to aggregate layer information.""" del inputs + var_name, curr_depth, parent_info = resolve_layer_context( + module_contexts[module_id], module_stack + ) info = LayerInfo(var_name, module, curr_depth, parent_info) info.calculate_num_params() info.check_recursive(layer_ids) summary_list.append(info) layer_ids.add(info.layer_id) global_layer_info[info.layer_id] = info + module_stack.append(info) return pre_hook def construct_hook( - global_layer_info: dict[int, LayerInfo], batch_dim: int | None + global_layer_info: dict[int, LayerInfo], + module_stack: list[LayerInfo], + batch_dim: int | None, ) -> Callable[[nn.Module, Any, Any], None]: def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: """Update LayerInfo after forward pass.""" @@ -605,6 +642,10 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: info.output_bytes = elem_bytes * prod(info.output_size) info.executed = True info.calculate_macs() + # Pop the frame pushed by this module's pre_hook. Hooks fire LIFO, so the + # top of the stack is always this module's own info. + if module_stack: + module_stack.pop() return hook @@ -628,6 +669,11 @@ def apply_hooks( layer_ids: set[int] = set() # Used to optimize is_recursive() global_layer_info: dict[int, LayerInfo] = {} hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] = {} + # All structural contexts each module is reached through. A module shared by + # several parents has more than one; the correct one is resolved at runtime. + module_contexts: dict[int, list[LayerContext]] = {} + # Stack of currently-executing layers, maintained by the pre/post hooks. + module_stack: list[LayerInfo] = [] stack: list[tuple[str, nn.Module, int, LayerInfo | None]] = [ (model_name, module, 0, None) ] @@ -640,25 +686,37 @@ def apply_hooks( global_layer_info[module_id] = LayerInfo( var_name, module, curr_depth, parent_info ) - pre_hook = construct_pre_hook( - global_layer_info, - summary_list, - layer_ids, - var_name, - curr_depth, - parent_info, + module_contexts.setdefault(module_id, []).append( + (var_name, curr_depth, parent_info) ) + if input_data is None or isinstance(module, WRAPPER_MODULES): - pre_hook(module, None) + # No forward pass is run, so build the layer directly from this + # occurrence's context rather than resolving it dynamically. + info = LayerInfo(var_name, module, curr_depth, parent_info) + info.calculate_num_params() + info.check_recursive(layer_ids) + summary_list.append(info) + layer_ids.add(info.layer_id) + global_layer_info[info.layer_id] = info else: # Register the hook using the last layer that uses this module. if module_id in hooks: for hook in hooks[module_id]: hook.remove() hooks[module_id] = ( - module.register_forward_pre_hook(pre_hook), + module.register_forward_pre_hook( + construct_pre_hook( + global_layer_info, + summary_list, + layer_ids, + module_contexts, + module_stack, + module_id, + ) + ), module.register_forward_hook( - construct_hook(global_layer_info, batch_dim) + construct_hook(global_layer_info, module_stack, batch_dim) ), ) From a87f465967202de582e5e48a04db1f8737d8009e Mon Sep 17 00:00:00 2001 From: vietlh Date: Sun, 7 Jun 2026 20:48:45 +0700 Subject: [PATCH 2/4] ci: re-trigger workflow (flaky HF 429 in test_flan_t5_small) Co-Authored-By: Claude Opus 4.8 (1M context) From 40ba2f02aa15af4ea96b8e0748dad94141d260e4 Mon Sep 17 00:00:00 2001 From: vietlh Date: Sat, 13 Jun 2026 13:51:57 +0700 Subject: [PATCH 3/4] refactor: address review feedback on shared-module fix - Move the LayerContext type alias to the top of the file alongside the other type definitions (review: "move new types to the top of the file"). - Guard the post-hook stack pop by module identity instead of blindly popping the top frame, so an exotic re-entrant hook can never corrupt another module's frame. - Document the LIFO hook-firing assumption and its failure modes at the pop site. Co-Authored-By: Claude Opus 4.8 (1M context) --- torchinfo/torchinfo.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index e79b623..309675e 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -29,6 +29,11 @@ INPUT_SIZE_TYPE = Sequence[int | Sequence[Any] | torch.Size] CORRECTED_INPUT_SIZE_TYPE = list[Sequence[Any] | torch.Size] +# A structural context a module is reached through during traversal: +# (var_name, depth, parent_info). A module shared by several parents has one per +# parent; the correct one is resolved at runtime by resolve_layer_context(). +LayerContext = tuple[str, int, "LayerInfo | None"] + DEFAULT_COLUMN_NAMES = (ColumnSettings.OUTPUT_SIZE, ColumnSettings.NUM_PARAMS) DEFAULT_ROW_SETTINGS = {RowSettings.DEPTH} REQUIRES_INPUT = { @@ -571,9 +576,6 @@ def get_correct_input_sizes(input_size: INPUT_SIZE_TYPE) -> CORRECTED_INPUT_SIZE return [input_size] -LayerContext = tuple[str, int, "LayerInfo | None"] - - def resolve_layer_context( contexts: list[LayerContext], module_stack: list[LayerInfo] ) -> LayerContext: @@ -642,9 +644,21 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: info.output_bytes = elem_bytes * prod(info.output_size) info.executed = True info.calculate_macs() - # Pop the frame pushed by this module's pre_hook. Hooks fire LIFO, so the - # top of the stack is always this module's own info. - if module_stack: + # Pop the frame pushed by this module's pre_hook. + # + # LIFO assumption: a module's forward calls its children's forwards, so by + # Python call-stack nesting the parent's post-hook always fires after every + # child's post-hook. The frame this module pushed in its pre_hook is + # therefore on top of the stack here. + # + # Failure modes this guards against: the assumption can only be perturbed by + # an exotic setup (e.g. a custom hook that re-enters forward, or a child + # whose post-hook is somehow skipped). Rather than blindly popping — which + # would corrupt another module's frame — we pop only when the top frame is + # this module's own. On a mismatch we leave the stack untouched; the worst + # case is a slightly stale stack used for parent resolution, never a wrong + # pop. + if module_stack and module_stack[-1].module is module: module_stack.pop() return hook From 42b0ac383132b52255e2c0b47fea249600c02b1a Mon Sep 17 00:00:00 2001 From: vietlh Date: Sun, 14 Jun 2026 17:00:01 +0700 Subject: [PATCH 4/4] feat: warn on unexpected forward-hook order in post-hook Per review: keep the fail-safe guarded pop, but emit a warning when the top stack frame is not this module's own, so a (rare) re-entrant/exotic hook setup that could mislabel the shared-module hierarchy is visible instead of silent. Co-Authored-By: Claude Opus 4.8 (1M context) --- torchinfo/torchinfo.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index 309675e..50a22a8 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -655,11 +655,22 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: # an exotic setup (e.g. a custom hook that re-enters forward, or a child # whose post-hook is somehow skipped). Rather than blindly popping — which # would corrupt another module's frame — we pop only when the top frame is - # this module's own. On a mismatch we leave the stack untouched; the worst - # case is a slightly stale stack used for parent resolution, never a wrong - # pop. + # this module's own. On a mismatch we leave the stack untouched (fail-safe: + # the worst case is a slightly stale stack used for parent resolution of a + # shared module, never a wrong pop) and warn, since it may indicate a parent + # row is mislabeled in the printed tree. if module_stack and module_stack[-1].module is module: module_stack.pop() + else: + top = module_stack[-1].module if module_stack else None + warnings.warn( + f"Unexpected forward-hook order: expected {type(module).__name__} " + f"on top of the execution stack, found " + f"{type(top).__name__ if top is not None else 'empty stack'}. " + "Skipping the stack pop to stay safe; the layer hierarchy for " + "shared modules may be slightly inaccurate.", + stacklevel=2, + ) return hook