Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.model(x))


class SharedModuleInNestedList(nn.Module):
    """Fixture for #327: one module instance reused under several parents.

    The model is a ModuleList of three blocks. Every block keeps its own two
    Linear layers in a nested ModuleList, but all blocks call the very same
    ReLU object. The same activation is therefore reachable through three
    different parents, which is the situation the issue describes.
    """

    class Block(nn.Module):
        """Two Linear(4, 4) layers plus an activation supplied by the owner."""

        def __init__(self, act: nn.Module) -> None:
            super().__init__()
            # `layers` is assigned before `act` so the child registration order
            # (layers, then act) is preserved.
            self.layers = nn.ModuleList()
            for _ in range(2):
                self.layers.append(nn.Linear(4, 4))
            # Not created here: the caller hands in the instance, so sibling
            # blocks end up referencing the same object.
            self.act = act

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = x
            for linear in self.layers:
                # The shared activation runs once after every Linear layer.
                out = self.act(linear(out))
            return out

    def __init__(self) -> None:
        super().__init__()
        relu = nn.ReLU()
        block_cls = SharedModuleInNestedList.Block
        self.blocks = nn.ModuleList()
        for _ in range(3):
            # Every block receives the same `relu` instance.
            self.blocks.append(block_cls(relu))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for blk in self.blocks:
            out = blk(out)
        return out


class PrunedLayerNameModel(nn.Module):
"""Model that defines parameters with _orig and _mask as suffixes."""

Expand Down
37 changes: 37 additions & 0 deletions tests/test_output/shared_module_in_nested_list.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
SharedModuleInNestedList [1, 4] --
├─ModuleList: 1-1 -- --
│ └─Block: 2-1 [1, 4] --
│ │ └─ModuleList: 3-3 -- (recursive)
│ │ │ └─Linear: 4-1 [1, 4] 20
│ │ └─ReLU: 3-2 [1, 4] --
│ │ └─ModuleList: 3-3 -- (recursive)
│ │ │ └─Linear: 4-2 [1, 4] 20
│ │ └─ReLU: 3-4 [1, 4] --
│ └─Block: 2-2 [1, 4] --
│ │ └─ModuleList: 3-7 -- (recursive)
│ │ │ └─Linear: 4-3 [1, 4] 20
│ │ └─ReLU: 3-6 [1, 4] --
│ │ └─ModuleList: 3-7 -- (recursive)
│ │ │ └─Linear: 4-4 [1, 4] 20
│ │ └─ReLU: 3-8 [1, 4] --
│ └─Block: 2-3 [1, 4] --
│ │ └─ModuleList: 3-11 -- (recursive)
│ │ │ └─Linear: 4-5 [1, 4] 20
│ │ └─ReLU: 3-10 [1, 4] --
│ │ └─ModuleList: 3-11 -- (recursive)
│ │ │ └─Linear: 4-6 [1, 4] 20
│ │ └─ReLU: 3-12 [1, 4] --
==========================================================================================
Total params: 120
Trainable params: 120
Non-trainable params: 0
Total mult-adds: 120
==========================================================================================
Input size: 72
Forward/backward pass size: 192
Params size: 480
Estimated Total Size: 744
==========================================================================================
16 changes: 16 additions & 0 deletions tests/torchinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ReuseLinearExtended,
ReuseReLU,
ScalarTensorInputNet,
SharedModuleInNestedList,
SiameseNets,
SimpleRNN,
SingleInputNet,
Expand Down Expand Up @@ -641,6 +642,21 @@ def test_cnn_module_list() -> None:
summary(CNNModuleList(ConvLayerB), input_size=[1, 1, 10])


def test_shared_module_in_nested_list() -> None:
    # Regression test for #327. One ReLU instance is reused by every block and
    # the blocks keep their layers in nested ModuleLists; that combination used
    # to corrupt the printed hierarchy and double-count parameters.
    net = SharedModuleInNestedList()

    # The Linear leaves sit at depth 4 (model > ModuleList > Block > ModuleList
    # > Linear), so depth=4 is needed for their rows to appear in the
    # Param # column.
    stats = summary(net, input_size=(1, 4), depth=4)

    # 3 blocks * 2 Linear(4, 4) layers, each with 4*4 weights + 4 biases = 20
    # parameters, gives 120 in total.
    trainable_count = sum(
        param.numel() for param in net.parameters() if param.requires_grad
    )
    assert stats.total_params == trainable_count
    assert trainable_count == 120
    assert stats.trainable_params == 120


def test_hide_recursive_layers() -> None:
model = SimpleRNN()
summary(model, input_size=(2, 3))
Expand Down
59 changes: 41 additions & 18 deletions torchinfo/layer_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, cast

import numpy as np
import torch
Expand Down Expand Up @@ -305,36 +305,59 @@ def params_percent(
return zero
return f"{params / total_params:>{precision + spacing}.{precision}%}"

def _leftover(self, attr: str) -> int:
    """
    Common implementation behind the leftover_* accessors.

    Starts from this layer's own `attr` value (num_params / trainable_params /
    param_bytes) and subtracts what the rows in `self.children` account for,
    returning what is left for this layer itself.

    Rules applied while scanning `self.children` in order:
    - a non-recursive leaf row contributes its full `attr` value;
    - a non-recursive non-leaf row contributes only its own leftover, because
      the rows beneath it are visited by this same scan;
    - a recursive row contributes its full `attr` value, but only the first
      time its layer_id is met in this scan, and every following row that is
      deeper than it is skipped as part of its subtree.
    """
    remaining = cast("int", getattr(self, attr))
    handled_ids: set[int] = set()
    # Depth of the recursive row whose subtree is currently being skipped.
    skip_below: int | None = None

    for row in self.children:
        if skip_below is not None:
            if row.depth > skip_below:
                # Still inside the recursive row's subtree.
                continue
            skip_below = None

        if row.is_recursive:
            if row.layer_id not in handled_ids:
                handled_ids.add(row.layer_id)
                remaining -= cast("int", getattr(row, attr))
            skip_below = row.depth
            continue

        handled_ids.add(row.layer_id)
        if row.is_leaf_layer:
            remaining -= cast("int", getattr(row, attr))
        else:
            remaining -= row._leftover(attr)

    return remaining

def leftover_params(self) -> int:
"""
Leftover params are the number of params this current layer has that are not
included in the child num_param counts.
"""
return self.num_params - sum(
child.num_params if child.is_leaf_layer else child.leftover_params()
for child in self.children
if not child.is_recursive
)
return self._leftover("num_params")

def leftover_trainable_params(self) -> int:
return self.trainable_params - sum(
child.trainable_params
if child.is_leaf_layer
else child.leftover_trainable_params()
for child in self.children
if not child.is_recursive
)
return self._leftover("trainable_params")

def leftover_param_bytes(self) -> int:
"""
Leftover param bytes are the bytes of this layer's params that are not
included in the child param_bytes counts.
"""
return self.param_bytes - sum(
child.param_bytes if child.is_leaf_layer else child.leftover_param_bytes()
for child in self.children
if not child.is_recursive
)
return self._leftover("param_bytes")


def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], int]:
Expand Down
111 changes: 97 additions & 14 deletions torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
INPUT_SIZE_TYPE = Sequence[int | Sequence[Any] | torch.Size]
CORRECTED_INPUT_SIZE_TYPE = list[Sequence[Any] | torch.Size]

# A structural context a module is reached through during traversal:
# (var_name, depth, parent_info). A module shared by several parents has one per
# parent; the correct one is resolved at runtime by resolve_layer_context().
LayerContext = tuple[str, int, "LayerInfo | None"]

DEFAULT_COLUMN_NAMES = (ColumnSettings.OUTPUT_SIZE, ColumnSettings.NUM_PARAMS)
DEFAULT_ROW_SETTINGS = {RowSettings.DEPTH}
REQUIRES_INPUT = {
Expand Down Expand Up @@ -571,29 +576,63 @@ def get_correct_input_sizes(input_size: INPUT_SIZE_TYPE) -> CORRECTED_INPUT_SIZE
return [input_size]


def resolve_layer_context(
    contexts: list[LayerContext], module_stack: list[LayerInfo]
) -> LayerContext:
    """
    Choose which recorded (var_name, depth, parent) context applies to the
    forward call that is happening right now.

    A module instance reachable through several parents has one context per
    parent. For each candidate, the parent chain is followed upwards until it
    reaches a layer whose module is on the runtime stack. Layers whose module
    never runs a forward of its own (e.g. ModuleList/ModuleDict) are passed
    over this way. The candidate wins if that layer's module is the one on top
    of the stack.

    Args:
        contexts: Contexts registered for the module; must not be empty.
        module_stack: Layers whose forward is currently executing, innermost
            last.

    Returns:
        The first matching context. Falls back to the first registered context
        when there is only one, when the stack is empty, or when none match.
    """
    fallback = contexts[0]
    if not module_stack or len(contexts) == 1:
        return fallback

    executing = module_stack[-1].module
    live_ids = {id(frame.module) for frame in module_stack}

    def nearest_live_ancestor(info: LayerInfo | None) -> LayerInfo | None:
        # Climb the parent chain until a layer whose module is on the stack.
        while info is not None and id(info.module) not in live_ids:
            info = info.parent_info
        return info

    for candidate in contexts:
        anchor = nearest_live_ancestor(candidate[2])
        if anchor is not None and anchor.module is executing:
            return candidate
    return fallback


def construct_pre_hook(
global_layer_info: dict[int, LayerInfo],
summary_list: list[LayerInfo],
layer_ids: set[int],
var_name: str,
curr_depth: int,
parent_info: LayerInfo | None,
module_contexts: dict[int, list[LayerContext]],
module_stack: list[LayerInfo],
module_id: int,
) -> Callable[[nn.Module, Any], None]:
def pre_hook(module: nn.Module, inputs: Any) -> None:
"""Create a LayerInfo object to aggregate layer information."""
del inputs
var_name, curr_depth, parent_info = resolve_layer_context(
module_contexts[module_id], module_stack
)
info = LayerInfo(var_name, module, curr_depth, parent_info)
info.calculate_num_params()
info.check_recursive(layer_ids)
summary_list.append(info)
layer_ids.add(info.layer_id)
global_layer_info[info.layer_id] = info
module_stack.append(info)

return pre_hook


def construct_hook(
global_layer_info: dict[int, LayerInfo], batch_dim: int | None
global_layer_info: dict[int, LayerInfo],
module_stack: list[LayerInfo],
batch_dim: int | None,
) -> Callable[[nn.Module, Any, Any], None]:
def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
"""Update LayerInfo after forward pass."""
Expand All @@ -605,6 +644,33 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
info.output_bytes = elem_bytes * prod(info.output_size)
info.executed = True
info.calculate_macs()
# Pop the frame pushed by this module's pre_hook.
#
# LIFO assumption: a module's forward calls its children's forwards, so by
# Python call-stack nesting the parent's post-hook always fires after every
# child's post-hook. The frame this module pushed in its pre_hook is
# therefore on top of the stack here.
#
# Failure modes this guards against: the assumption can only be perturbed by
# an exotic setup (e.g. a custom hook that re-enters forward, or a child
# whose post-hook is somehow skipped). Rather than blindly popping — which
# would corrupt another module's frame — we pop only when the top frame is
# this module's own. On a mismatch we leave the stack untouched (fail-safe:
# the worst case is a slightly stale stack used for parent resolution of a
# shared module, never a wrong pop) and warn, since it may indicate a parent
# row is mislabeled in the printed tree.
if module_stack and module_stack[-1].module is module:
module_stack.pop()
else:
top = module_stack[-1].module if module_stack else None
warnings.warn(
f"Unexpected forward-hook order: expected {type(module).__name__} "
f"on top of the execution stack, found "
f"{type(top).__name__ if top is not None else 'empty stack'}. "
"Skipping the stack pop to stay safe; the layer hierarchy for "
"shared modules may be slightly inaccurate.",
stacklevel=2,
)

return hook

Expand All @@ -628,6 +694,11 @@ def apply_hooks(
layer_ids: set[int] = set() # Used to optimize is_recursive()
global_layer_info: dict[int, LayerInfo] = {}
hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] = {}
# All structural contexts each module is reached through. A module shared by
# several parents has more than one; the correct one is resolved at runtime.
module_contexts: dict[int, list[LayerContext]] = {}
# Stack of currently-executing layers, maintained by the pre/post hooks.
module_stack: list[LayerInfo] = []
stack: list[tuple[str, nn.Module, int, LayerInfo | None]] = [
(model_name, module, 0, None)
]
Expand All @@ -640,25 +711,37 @@ def apply_hooks(
global_layer_info[module_id] = LayerInfo(
var_name, module, curr_depth, parent_info
)
pre_hook = construct_pre_hook(
global_layer_info,
summary_list,
layer_ids,
var_name,
curr_depth,
parent_info,
module_contexts.setdefault(module_id, []).append(
(var_name, curr_depth, parent_info)
)

if input_data is None or isinstance(module, WRAPPER_MODULES):
pre_hook(module, None)
# No forward pass is run, so build the layer directly from this
# occurrence's context rather than resolving it dynamically.
info = LayerInfo(var_name, module, curr_depth, parent_info)
info.calculate_num_params()
info.check_recursive(layer_ids)
summary_list.append(info)
layer_ids.add(info.layer_id)
global_layer_info[info.layer_id] = info
else:
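# Keep exactly one pre/post hook pair per module: drop the pair registered on
# an earlier visit and register a fresh one. The (name, depth, parent) context
# is looked up from module_contexts when the hook fires.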
if module_id in hooks:
for hook in hooks[module_id]:
hook.remove()
hooks[module_id] = (
module.register_forward_pre_hook(pre_hook),
module.register_forward_pre_hook(
construct_pre_hook(
global_layer_info,
summary_list,
layer_ids,
module_contexts,
module_stack,
module_id,
)
),
module.register_forward_hook(
construct_hook(global_layer_info, batch_dim)
construct_hook(global_layer_info, module_stack, batch_dim)
),
)

Expand Down
Loading