fix: handle torch.distributions.Distribution outputs in summary() (#329) #394

Merged
TylerYep merged 3 commits into TylerYep:main from Mikyx-1:fix/issue-329
Jun 7, 2026
Mikyx-1 commented Jun 6, 2026

Contributor

Bug

torchinfo.summary() raises a TypeError when a model's forward() returns a torch.distributions.Distribution object (e.g. Categorical, Normal) instead of a plain tensor.

Root cause: LayerInfo.calculate_size() in layer_info.py handles Tensor, ndarray, dict, list, and tuple, but falls through to a hard raise TypeError for any other type — including distribution objects.

Reproduction (from #329):

import torch, torch.nn as nn, torchinfo
from torch.distributions import Categorical

class ProbabilisticModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return Categorical(logits=self.final(self.linear(x)))

torchinfo.summary(ProbabilisticModel(), input_size=(1, 10))
# TypeError: Model contains a layer with an unsupported input or output type

Fix

Add an elif branch in calculate_size() for torch.distributions.Distribution:

elif isinstance(inputs, torch.distributions.Distribution):
    size = list(inputs.batch_shape + inputs.event_shape)
    elem_bytes = 0
  • batch_shape + event_shape gives the meaningful output shape (e.g. [1] for Categorical with a batch of 1).
  • elem_bytes = 0 since there is no single underlying tensor to measure.
  • Covers all Distribution subclasses (Categorical, Normal, Bernoulli, MultivariateNormal, etc.), not just the one from the issue report.
  • The existing else: raise TypeError is preserved for truly unknown types.
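
To make the shape rule above concrete, here is a minimal sketch of the same dispatch outside torchinfo. The helper name `output_size` is hypothetical and is not part of the library; it only mirrors the idea of the new branch (batch dims followed by event dims, with `elem_bytes` set to 0) and keeps a `TypeError` for anything unrecognized.

```python
import torch
from torch.distributions import Categorical, Distribution, MultivariateNormal, Normal


def output_size(output):
    """Hypothetical stand-in for the size logic; not torchinfo's actual code."""
    if isinstance(output, torch.Tensor):
        return list(output.shape), output.element_size()
    if isinstance(output, Distribution):
        # Same expression as the new branch: batch dims followed by event dims.
        return list(output.batch_shape + output.event_shape), 0
    # Unknown types still fail loudly, as in the preserved else branch.
    raise TypeError(f"unsupported output type: {type(output).__name__}")


# Categorical has an empty event_shape, so only the batch dim remains.
categorical = Categorical(logits=torch.zeros(1, 5))
# Normal treats every element as its own batch entry.
normal = Normal(loc=torch.zeros(4, 3), scale=torch.ones(4, 3))
# MultivariateNormal has a non-empty event_shape (the vector dimension).
mvn = MultivariateNormal(loc=torch.zeros(2, 3), covariance_matrix=torch.eye(3))

for dist in (categorical, normal, mvn):
    size, elem_bytes = output_size(dist)
    print(type(dist).__name__, size, elem_bytes)
```

Under these assumptions the reported sizes are `[1]`, `[4, 3]`, and `[2, 3]`. Because `elem_bytes` is 0, a distribution output would contribute nothing to a byte-based size estimate.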

Tests

Closes #329

🤖 Generated with Claude Code

Mikyx-1 and others added 2 commits June 6, 2026 11:30
Models that return a distribution object (e.g. Categorical) from their
forward pass previously raised a TypeError. Now batch_shape + event_shape
is used as the output size with elem_bytes=0, covering all Distribution
subclasses.

Fixes TylerYep#329

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>

Adds CategoricalOutputModel and NormalOutputModel fixtures and two
tests that verify summary() completes and reports correct shapes for
models returning torch.distributions.Categorical and Normal.

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
@TylerYep TylerYep merged commit 3a0de39 into TylerYep:main Jun 7, 2026
4 checks passed
TylerYep commented Jun 7, 2026

Owner

Thank you for the fix!

Development

Successfully merging this pull request may close these issues.

Issue with torchinfo.summary() Failing on Models with torch.distributions.Categorical