From 697d513837d7977237eb06ac01c8d85f198b3af2 Mon Sep 17 00:00:00 2001 From: ChanifRusydi <62838892+ChanifRusydi@users.noreply.github.com> Date: Tue, 16 Sep 2025 19:38:13 +0700 Subject: [PATCH 1/5] Fixing #371 and tried adding directml support #368 --- torchinfo/torchinfo.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index c0b6bb5b..919772d9 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -16,10 +16,13 @@ ) import numpy as np + import torch from torch import nn from torch.jit import ScriptModule +from torch.return_types import mode from torch.utils.hooks import RemovableHandle +from packaging import version from .enums import ColumnSettings, Mode, RowSettings, Verbosity from .formatting import FormattingOptions @@ -473,8 +476,17 @@ def get_device( If input_data is given, the device should not be changed (to allow for multi-device models, etc.) - Otherwise gets device of first parameter of model and returns it if it is on cuda, - otherwise returns cuda if available or cpu if not. + Otherwise gets device of first parameter of model and returns it. + As #371 mentioned, in system with other accelerator (such as MPS), + returns the accelerator (introduced in pytorch 2.6.0) device if available, + otherwise returns cuda if cuda is available, if not then return cpu. + + Attempting to also add DirectML device (may include AMD GPU support). + However since personally didn't have any AMD GPU at the moment, + the implementation is solely based on documentation and examples. + DirectML documentation and example: + https://github.com/microsoft/DirectML + https://medium.com/@ochwada/preparations-for-pytorch-and-directml-on-amd-ryzen-9-6950h-for-ai-projects-15c164d22332 """ if input_data is None: try: @@ -482,9 +494,23 @@ def get_device( except StopIteration: model_parameter = None - if model_parameter is not None and model_parameter.is_cuda: + if model_parameter is not None and model_parameter.device: return model_parameter.device - return torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Since torch.accelerator is available in torch 2.6.0 and above + if version.parse(torch.__version__) >= version.parse("2.6.0"): + try: + if torch.accelerator.is_available(): + return torch.accelerator.current_accelerator() + except Exception as e: + print(f"Error occurred while getting device: {e}") + try: + import torch_directml + print("DirectML plugin is installed, using DirectML device.") + if model_parameter is not None: + return torch_directml.device(torch_directml.default_device()) + except ImportError: + print("DirectML plugin is not installed.") + return torch.device("cpu") return None From 1f05e1dc2fb2568badd3019bd98f79d52391b889 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:50:48 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchinfo/torchinfo.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index 919772d9..b20a4794 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -16,13 +16,11 @@ ) import numpy as np - import torch +from packaging import version from torch import nn from torch.jit import ScriptModule -from torch.return_types import mode from torch.utils.hooks import RemovableHandle -from packaging import version from .enums import ColumnSettings, Mode, RowSettings, Verbosity from .formatting import FormattingOptions @@ -497,7 +495,7 @@ def get_device( if model_parameter is not None and model_parameter.device: return model_parameter.device # Since torch.accelerator is available in torch 2.6.0 and above - if version.parse(torch.__version__) >= version.parse("2.6.0"): + if version.parse(torch.__version__) >= version.parse("2.6.0"): try: if torch.accelerator.is_available(): return torch.accelerator.current_accelerator() @@ -505,6 +503,7 @@ def get_device( print(f"Error occurred while getting device: {e}") try: import torch_directml + print("DirectML plugin is installed, using DirectML device.") if model_parameter is not None: return torch_directml.device(torch_directml.default_device()) From b7dcc55b7c09ab326879706dfd7d3ab44a658073 Mon Sep 17 00:00:00 2001 From: ChanifRusydi <62838892+ChanifRusydi@users.noreply.github.com> Date: Wed, 17 Sep 2025 19:47:54 +0700 Subject: [PATCH 3/5] Fixing mypy import error in CI --- torchinfo/torchinfo.py | 699 ----------------------------------------- 1 file changed, 699 deletions(-) delete mode 100644 torchinfo/torchinfo.py diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py deleted file mode 100644 index b20a4794..00000000 --- a/torchinfo/torchinfo.py +++ /dev/null @@ -1,699 +0,0 @@ -from __future__ import annotations - -import sys -import warnings -from typing import ( - Any, - Callable, - Iterable, - Iterator, - List, - Mapping, - Optional, - Sequence, - Union, - cast, -) - -import numpy as np -import torch -from packaging import version -from torch import nn -from torch.jit import ScriptModule -from torch.utils.hooks import RemovableHandle - -from .enums import ColumnSettings, Mode, RowSettings, Verbosity -from .formatting import FormattingOptions -from .layer_info import LayerInfo, get_children_layers, prod -from .model_statistics import ModelStatistics - -# Some modules do the computation themselves using parameters -# or the parameters of children. Treat these as layers. -# TODO: figure out a test case for this -LAYER_MODULES = (torch.nn.MultiheadAttention,) -# These modules are not recorded during a forward pass. Handle them separately. -WRAPPER_MODULES = (ScriptModule,) - -INPUT_DATA_TYPE = Union[ - torch.Tensor, np.ndarray, Sequence[Any], Mapping[str, Any] # type: ignore[type-arg] -] -CORRECTED_INPUT_DATA_TYPE = Optional[Union[Iterable[Any], Mapping[Any, Any]]] -INPUT_SIZE_TYPE = Sequence[Union[int, Sequence[Any], torch.Size]] -CORRECTED_INPUT_SIZE_TYPE = List[Union[Sequence[Any], torch.Size]] - -DEFAULT_COLUMN_NAMES = (ColumnSettings.OUTPUT_SIZE, ColumnSettings.NUM_PARAMS) -DEFAULT_ROW_SETTINGS = {RowSettings.DEPTH} -REQUIRES_INPUT = { - ColumnSettings.INPUT_SIZE, - ColumnSettings.OUTPUT_SIZE, - ColumnSettings.MULT_ADDS, -} - -_cached_forward_pass: dict[str, list[LayerInfo]] = {} - - -def summary( - model: nn.Module, - input_size: INPUT_SIZE_TYPE | None = None, - input_data: INPUT_DATA_TYPE | None = None, - batch_dim: int | None = None, - cache_forward_pass: bool | None = None, - col_names: Iterable[str] | None = None, - col_width: int = 25, - depth: int = 3, - device: torch.device | str | None = None, - dtypes: list[torch.dtype] | None = None, - mode: str = "same", - row_settings: Iterable[str] | None = None, - verbose: int | None = None, - **kwargs: Any, -) -> ModelStatistics: - """ - Summarize the given PyTorch model. Summarized information includes: - 1) Layer names, - 2) input/output shapes, - 3) kernel shape, - 4) # of parameters, - 5) # of operations (Mult-Adds), - 6) whether layer is trainable - - NOTE: If neither input_data or input_size are provided, no forward pass through the - network is performed, and the provided model information is limited to layer names. - - Args: - model (nn.Module): - PyTorch model to summarize. The model should be fully in either train() - or eval() mode. If layers are not all in the same mode, running summary - may have side effects on batchnorm or dropout statistics. If you - encounter an issue with this, please open a GitHub issue. - - input_size (Sequence of Sizes): - Shape of input data as a List/Tuple/torch.Size - (dtypes must match model input, default is FloatTensors). - You should include batch size in the tuple. - Default: None - - input_data (Sequence of Tensors): - Arguments for the model's forward pass (dtypes inferred). - If the forward() function takes several parameters, pass in a list of - args or a dict of kwargs (if your forward() function takes in a dict - as its only argument, wrap it in a list). - Default: None - - batch_dim (int): - Batch_dimension of input data. If batch_dim is None, assume - input_data / input_size contains the batch dimension, which is used - in all calculations. Else, expand all tensors to contain the batch_dim. - Specifying batch_dim can be an runtime optimization, since if batch_dim - is specified, torchinfo uses a batch size of 1 for the forward pass. - Default: None - - cache_forward_pass (bool): - If True, cache the run of the forward() function using the model - class name as the key. If the forward pass is an expensive operation, - this can make it easier to modify the formatting of your model - summary, e.g. changing the depth or enabled column types, especially - in Jupyter Notebooks. - WARNING: Modifying the model architecture or input data/input size when - this feature is enabled does not invalidate the cache or re-run the - forward pass, and can cause incorrect summaries as a result. - Default: False - - col_names (Iterable[str]): - Specify which columns to show in the output. Currently supported: ( - "input_size", - "output_size", - "num_params", - "params_percent", - "kernel_size", - "groups", - "mult_adds", - "trainable", - ) - Default: ("output_size", "num_params") - If input_data / input_size are not provided, only "num_params" is used. - - col_width (int): - Width of each column. - Default: 25 - - depth (int): - Depth of nested layers to display (e.g. Sequentials). - Nested layers below this depth will not be displayed in the summary. - Default: 3 - - device (torch.Device): - Uses this torch device for model and input_data. - If not specified, uses the dtype of input_data if given, or the - parameters of the model. Otherwise, uses the result of - torch.cuda.is_available(). - Default: None - - dtypes (List[torch.dtype]): - If you use input_size, torchinfo assumes your input uses FloatTensors. - If your model use a different data type, specify that dtype. - For multiple inputs, specify the size of both inputs, and - also specify the types of each parameter here. - Default: None - - mode (str) - Either "train", "eval" or "same", which determines whether we call - model.train() or model.eval() before calling summary(). In any case, - original model mode is restored at the end. - Default: "same". - - row_settings (Iterable[str]): - Specify which features to show in a row. Currently supported: ( - "ascii_only", - "depth", - "var_names", - ) - Default: ("depth",) - - verbose (int): - 0 (quiet): No output - 1 (default): Print model summary - 2 (verbose): Show weight and bias layers in full detail - Default: 1 - If using a Juypter Notebook or Google Colab, the default is 0. - - **kwargs: - Other arguments used in `model.forward` function. Passing *args is no - longer supported. - - Return: - ModelStatistics object - See torchinfo/model_statistics.py for more information. - """ - input_data_specified = input_data is not None or input_size is not None - if col_names is None: - columns = ( - DEFAULT_COLUMN_NAMES - if input_data_specified - else (ColumnSettings.NUM_PARAMS,) - ) - else: - columns = tuple(ColumnSettings(name) for name in col_names) - - if row_settings is None: - rows = DEFAULT_ROW_SETTINGS - else: - rows = {RowSettings(name) for name in row_settings} - - model_mode = Mode(mode) - - if verbose is None: - verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1 - - if cache_forward_pass is None: - # In the future, this may be enabled by default in Jupyter Notebooks - cache_forward_pass = False - - if device is None: - device = get_device(model, input_data) - elif isinstance(device, str): - device = torch.device(device) - - validate_user_params( - input_data, input_size, columns, col_width, device, dtypes, verbose - ) - - x, correct_input_size = process_input( - input_data, input_size, batch_dim, device, dtypes - ) - summary_list = forward_pass( - model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs - ) - formatting = FormattingOptions(depth, verbose, columns, col_width, rows) - results = ModelStatistics( - summary_list, correct_input_size, get_total_memory_used(x), formatting - ) - if verbose > Verbosity.QUIET: - print(results) - return results - - -def process_input( - input_data: INPUT_DATA_TYPE | None, - input_size: INPUT_SIZE_TYPE | None, - batch_dim: int | None, - device: torch.device | None, - dtypes: list[torch.dtype] | None = None, -) -> tuple[CORRECTED_INPUT_DATA_TYPE, Any]: - """Reads sample input data to get the input size.""" - x = None - correct_input_size = [] - if input_data is not None: - correct_input_size = get_input_data_sizes(input_data) - x = set_device(input_data, device) - if isinstance(x, (torch.Tensor, np.ndarray)): - x = [x] - - if input_size is not None: - assert device is not None - if dtypes is None: - dtypes = [torch.float] * len(input_size) - correct_input_size = get_correct_input_sizes(input_size) - x = get_input_tensor(correct_input_size, batch_dim, dtypes, device) - return x, correct_input_size - - -def forward_pass( - model: nn.Module, - x: CORRECTED_INPUT_DATA_TYPE, - batch_dim: int | None, - cache_forward_pass: bool, - device: torch.device | None, - mode: Mode, - **kwargs: Any, -) -> list[LayerInfo]: - """Perform a forward pass on the model using forward hooks.""" - global _cached_forward_pass - model_name = model.__class__.__name__ - if cache_forward_pass and model_name in _cached_forward_pass: - return _cached_forward_pass[model_name] - - summary_list, _, hooks = apply_hooks(model_name, model, x, batch_dim) - if x is None: - set_children_layers(summary_list) - return summary_list - - kwargs = set_device(kwargs, device) - saved_model_mode = model.training - try: - if mode == Mode.TRAIN: - model.train() - elif mode == Mode.EVAL: - model.eval() - elif mode != Mode.SAME: - raise RuntimeError( - f"Specified model mode ({list(Mode)}) not recognized: {mode}" - ) - - with torch.no_grad(): - model = model if device is None else model.to(device) - if isinstance(x, (list, tuple)): - _ = model(*x, **kwargs) - elif isinstance(x, dict): - _ = model(**x, **kwargs) - else: - # Should not reach this point, since process_input_data ensures - # x is either a list, tuple, or dict - raise ValueError("Unknown input type") - except Exception as e: - executed_layers = [layer for layer in summary_list if layer.executed] - raise RuntimeError( - "Failed to run torchinfo. See above stack traces for more details. " - f"Executed layers up to: {executed_layers}" - ) from e - finally: - if hooks: - for pre_hook, hook in hooks.values(): - pre_hook.remove() - hook.remove() - model.train(saved_model_mode) - - add_missing_container_layers(summary_list) - set_children_layers(summary_list) - - _cached_forward_pass[model_name] = summary_list - return summary_list - - -def set_children_layers(summary_list: list[LayerInfo]) -> None: - """Populates the children and depth_index fields of all LayerInfo.""" - idx: dict[int, int] = {} - for i, layer in enumerate(summary_list): - idx[layer.depth] = idx.get(layer.depth, 0) + 1 - layer.depth_index = idx[layer.depth] - layer.children = get_children_layers(summary_list, i) - - -def add_missing_container_layers(summary_list: list[LayerInfo]) -> None: - """Finds container modules not in the currently listed hierarchy.""" - layer_ids = {layer.layer_id for layer in summary_list} - current_hierarchy: dict[int, LayerInfo] = {} - for idx, layer_info in enumerate(summary_list): - # to keep track index of current layer - # after inserting new layers - rel_idx = 0 - - # create full hierarchy of current layer - hierarchy = {} - parent = layer_info.parent_info - while parent is not None and parent.depth > 0: - hierarchy[parent.depth] = parent - parent = parent.parent_info - - # show hierarchy if it is not there already - for d in range(1, layer_info.depth): - if ( - d not in current_hierarchy - or current_hierarchy[d].module is not hierarchy[d].module - ) and hierarchy[d] is not summary_list[idx + rel_idx - 1]: - hierarchy[d].calculate_num_params() - hierarchy[d].check_recursive(layer_ids) - summary_list.insert(idx + rel_idx, hierarchy[d]) - layer_ids.add(hierarchy[d].layer_id) - - current_hierarchy[d] = hierarchy[d] - rel_idx += 1 - - current_hierarchy[layer_info.depth] = layer_info - - # remove deeper hierarchy - d = layer_info.depth + 1 - while d in current_hierarchy: - current_hierarchy.pop(d) - d += 1 - - -def validate_user_params( - input_data: INPUT_DATA_TYPE | None, - input_size: INPUT_SIZE_TYPE | None, - col_names: tuple[ColumnSettings, ...], - col_width: int, - device: torch.device | None, - dtypes: list[torch.dtype] | None, - verbose: int, -) -> None: - """Raise exceptions if the user's input is invalid.""" - if col_width <= 0: - raise ValueError(f"Column width must be greater than 0: col_width={col_width}") - if verbose not in (0, 1, 2): - raise ValueError( - "Verbose must be either 0 (quiet), 1 (default), or 2 (verbose)." - ) - both_input_specified = input_data is not None and input_size is not None - if both_input_specified: - raise RuntimeError("Only one of (input_data, input_size) should be specified.") - - neither_input_specified = input_data is None and input_size is None - not_allowed = set(col_names) & REQUIRES_INPUT - if neither_input_specified and not_allowed: - raise ValueError( - "You must pass input_data or input_size in order " - f"to use columns: {not_allowed}" - ) - - if dtypes is not None and any( - dtype in (torch.float16, torch.bfloat16) for dtype in dtypes - ): - if input_size is not None: - warnings.warn( - "Half precision is not supported with input_size parameter, and may " - "output incorrect results. Try passing input_data directly.", - stacklevel=2, - ) - if device is not None and device.type == "cpu": - warnings.warn( - "Half precision is not supported on cpu. Set the `device` field or " - "pass `input_data` using the correct device.", - stacklevel=2, - ) - - -def traverse_input_data( - data: Any, action_fn: Callable[..., Any], aggregate_fn: Callable[..., Any] -) -> Any: - """ - Traverses any type of nested input data. On a tensor, returns the action given by - action_fn, and afterwards aggregates the results using aggregate_fn. - """ - if isinstance(data, torch.Tensor): - result = action_fn(data) - elif isinstance(data, np.ndarray): - result = action_fn(torch.from_numpy(data)) - # If the result of action_fn is a torch.Tensor, then action_fn was meant for - # torch.Tensors only (like calling .to(...)) -> Ignore. - if isinstance(result, torch.Tensor): - result = data - - # Recursively apply to collection items - elif isinstance(data, Mapping): - aggregate = aggregate_fn(data) - result = aggregate( - { - k: traverse_input_data(v, action_fn, aggregate_fn) - for k, v in data.items() - } - ) - elif isinstance(data, tuple) and hasattr(data, "_fields"): # Named tuple - aggregate = aggregate_fn(data) - result = aggregate( - *(traverse_input_data(d, action_fn, aggregate_fn) for d in data) - ) - elif isinstance(data, Iterable) and not isinstance(data, str): - aggregate = aggregate_fn(data) - result = aggregate( - [traverse_input_data(d, action_fn, aggregate_fn) for d in data] - ) - else: - # Data is neither a tensor nor a collection - result = data - return result - - -def set_device(data: Any, device: torch.device | None) -> Any: - """Sets device for all input types and collections of input types.""" - return ( - data - if device is None - else traverse_input_data( - data, - action_fn=lambda data: data.to(device, non_blocking=True), - aggregate_fn=type, - ) - ) - - -def get_device( - model: nn.Module, input_data: INPUT_DATA_TYPE | None -) -> torch.device | None: - """ - If input_data is given, the device should not be changed - (to allow for multi-device models, etc.) - - Otherwise gets device of first parameter of model and returns it. - As #371 mentioned, in system with other accelerator (such as MPS), - returns the accelerator (introduced in pytorch 2.6.0) device if available, - otherwise returns cuda if cuda is available, if not then return cpu. - - Attempting to also add DirectML device (may include AMD GPU support). - However since personally didn't have any AMD GPU at the moment, - the implementation is solely based on documentation and examples. - DirectML documentation and example: - https://github.com/microsoft/DirectML - https://medium.com/@ochwada/preparations-for-pytorch-and-directml-on-amd-ryzen-9-6950h-for-ai-projects-15c164d22332 - """ - if input_data is None: - try: - model_parameter = next(model.parameters()) - except StopIteration: - model_parameter = None - - if model_parameter is not None and model_parameter.device: - return model_parameter.device - # Since torch.accelerator is available in torch 2.6.0 and above - if version.parse(torch.__version__) >= version.parse("2.6.0"): - try: - if torch.accelerator.is_available(): - return torch.accelerator.current_accelerator() - except Exception as e: - print(f"Error occurred while getting device: {e}") - try: - import torch_directml - - print("DirectML plugin is installed, using DirectML device.") - if model_parameter is not None: - return torch_directml.device(torch_directml.default_device()) - except ImportError: - print("DirectML plugin is not installed.") - return torch.device("cpu") - return None - - -def get_input_data_sizes(data: Any) -> Any: - """ - Converts input data to an equivalent data structure of torch.Sizes - instead of tensors. - """ - return traverse_input_data( - data, action_fn=lambda data: data.size(), aggregate_fn=type - ) - - -def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int: - """Calculates the total memory of all tensors stored in data.""" - result = traverse_input_data( - data, - action_fn=lambda data: sys.getsizeof( - data.untyped_storage() - if hasattr(data, "untyped_storage") - else data.storage() - ), - aggregate_fn=( - # We don't need the dictionary keys in this case - lambda data: (lambda d: sum(d.values())) - if isinstance(data, Mapping) - else sum - ), - ) - return cast(int, result) - - -def get_input_tensor( - input_size: CORRECTED_INPUT_SIZE_TYPE, - batch_dim: int | None, - dtypes: list[torch.dtype], - device: torch.device, -) -> list[torch.Tensor]: - """Get input_tensor with batch size 1 for use in model.forward()""" - x = [] - for size, dtype in zip(input_size, dtypes): - input_tensor = torch.rand(*size) - if batch_dim is not None: - input_tensor = input_tensor.unsqueeze(dim=batch_dim) - x.append(input_tensor.to(device).type(dtype)) - return x - - -def flatten(nested_array: INPUT_SIZE_TYPE) -> Iterator[Any]: - """Flattens a nested array.""" - for item in nested_array: - if isinstance(item, (list, tuple)): - yield from flatten(item) - else: - yield item - - -def get_correct_input_sizes(input_size: INPUT_SIZE_TYPE) -> CORRECTED_INPUT_SIZE_TYPE: - """ - Convert input_size to the correct form, which is a list of tuples. - Also handles multiple inputs to the network. - """ - if not isinstance(input_size, (list, tuple)): - raise TypeError( - "Input_size is not a recognized type. Please ensure input_size is valid.\n" - "For multiple inputs to the network, ensure input_size is a list of tuple " - "sizes. If you are having trouble here, please submit a GitHub issue." - ) - if not input_size or any(size <= 0 for size in flatten(input_size)): - raise ValueError("Input_data is invalid, or negative size found in input_data.") - - if isinstance(input_size, list) and isinstance(input_size[0], int): - return [tuple(input_size)] - if isinstance(input_size, list): - return input_size - if isinstance(input_size, tuple) and isinstance(input_size[0], tuple): - return list(input_size) - return [input_size] - - -def construct_pre_hook( - global_layer_info: dict[int, LayerInfo], - summary_list: list[LayerInfo], - layer_ids: set[int], - var_name: str, - curr_depth: int, - parent_info: LayerInfo | None, -) -> Callable[[nn.Module, Any], None]: - def pre_hook(module: nn.Module, inputs: Any) -> None: - """Create a LayerInfo object to aggregate layer information.""" - del inputs - info = LayerInfo(var_name, module, curr_depth, parent_info) - info.calculate_num_params() - info.check_recursive(layer_ids) - summary_list.append(info) - layer_ids.add(info.layer_id) - global_layer_info[info.layer_id] = info - - return pre_hook - - -def construct_hook( - global_layer_info: dict[int, LayerInfo], batch_dim: int | None -) -> Callable[[nn.Module, Any, Any], None]: - def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: - """Update LayerInfo after forward pass.""" - info = global_layer_info[id(module)] - if info.contains_lazy_param: - info.calculate_num_params() - info.input_size, _ = info.calculate_size(inputs, batch_dim) - info.output_size, elem_bytes = info.calculate_size(outputs, batch_dim) - info.output_bytes = elem_bytes * prod(info.output_size) - info.executed = True - info.calculate_macs() - - return hook - - -def apply_hooks( - model_name: str, - module: nn.Module, - input_data: CORRECTED_INPUT_DATA_TYPE, - batch_dim: int | None, -) -> tuple[ - list[LayerInfo], - dict[int, LayerInfo], - dict[int, tuple[RemovableHandle, RemovableHandle]], -]: - """ - If input_data is provided, recursively adds hooks to all layers of the model. - Else, fills summary_list with layer info without computing a - forward pass through the network. - """ - summary_list: list[LayerInfo] = [] - layer_ids: set[int] = set() # Used to optimize is_recursive() - global_layer_info: dict[int, LayerInfo] = {} - hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] = {} - stack: list[tuple[str, nn.Module, int, LayerInfo | None]] = [ - (model_name, module, 0, None) - ] - while stack: - var_name, module, curr_depth, parent_info = stack.pop() - module_id = id(module) - - # Fallback is used if the layer's pre-hook is never called, for example in - # ModuleLists or Sequentials. - global_layer_info[module_id] = LayerInfo( - var_name, module, curr_depth, parent_info - ) - pre_hook = construct_pre_hook( - global_layer_info, - summary_list, - layer_ids, - var_name, - curr_depth, - parent_info, - ) - if input_data is None or isinstance(module, WRAPPER_MODULES): - pre_hook(module, None) - else: - # Register the hook using the last layer that uses this module. - if module_id in hooks: - for hook in hooks[module_id]: - hook.remove() - hooks[module_id] = ( - module.register_forward_pre_hook(pre_hook), - module.register_forward_hook( - construct_hook(global_layer_info, batch_dim) - ), - ) - - # Replaces the equivalent recursive call by appending all of the - # subsequent the module children stack calls in the encountered order. - # Note: module.named_modules(remove_duplicate=False) doesn't work for - # some unknown reason (infinite recursion) - stack += [ - (name, mod, curr_depth + 1, global_layer_info[module_id]) - for name, mod in reversed(module._modules.items()) - if mod is not None - ] - return summary_list, global_layer_info, hooks - - -def clear_cached_forward_pass() -> None: - """Clear the forward pass cache.""" - global _cached_forward_pass - _cached_forward_pass = {} From 4099c1078f7ac19364f2ba428c0c2b2328111589 Mon Sep 17 00:00:00 2001 From: ChanifRusydi <62838892+ChanifRusydi@users.noreply.github.com> Date: Wed, 17 Sep 2025 21:11:51 +0700 Subject: [PATCH 4/5] adding type:ignore to fix mypy error --- torchinfo/torchinfo.py | 702 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 702 insertions(+) create mode 100644 torchinfo/torchinfo.py diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py new file mode 100644 index 00000000..b9f2ff74 --- /dev/null +++ b/torchinfo/torchinfo.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +import sys +import warnings +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Union, + cast, +) + +import numpy as np + +import torch +from torch import nn +from torch.jit import ScriptModule +from torch.return_types import mode +from torch.utils.hooks import RemovableHandle +from packaging import version + +from .enums import ColumnSettings, Mode, RowSettings, Verbosity +from .formatting import FormattingOptions +from .layer_info import LayerInfo, get_children_layers, prod +from .model_statistics import ModelStatistics + +# Some modules do the computation themselves using parameters +# or the parameters of children. Treat these as layers. +# TODO: figure out a test case for this +LAYER_MODULES = (torch.nn.MultiheadAttention,) +# These modules are not recorded during a forward pass. Handle them separately. +WRAPPER_MODULES = (ScriptModule,) + +INPUT_DATA_TYPE = Union[ + torch.Tensor, np.ndarray, Sequence[Any], Mapping[str, Any] # type: ignore[type-arg] +] +CORRECTED_INPUT_DATA_TYPE = Optional[Union[Iterable[Any], Mapping[Any, Any]]] +INPUT_SIZE_TYPE = Sequence[Union[int, Sequence[Any], torch.Size]] +CORRECTED_INPUT_SIZE_TYPE = List[Union[Sequence[Any], torch.Size]] + +DEFAULT_COLUMN_NAMES = (ColumnSettings.OUTPUT_SIZE, ColumnSettings.NUM_PARAMS) +DEFAULT_ROW_SETTINGS = {RowSettings.DEPTH} +REQUIRES_INPUT = { + ColumnSettings.INPUT_SIZE, + ColumnSettings.OUTPUT_SIZE, + ColumnSettings.MULT_ADDS, +} + +_cached_forward_pass: dict[str, list[LayerInfo]] = {} + + +def summary( + model: nn.Module, + input_size: INPUT_SIZE_TYPE | None = None, + input_data: INPUT_DATA_TYPE | None = None, + batch_dim: int | None = None, + cache_forward_pass: bool | None = None, + col_names: Iterable[str] | None = None, + col_width: int = 25, + depth: int = 3, + device: torch.device | str | None = None, + dtypes: list[torch.dtype] | None = None, + mode: str = "same", + row_settings: Iterable[str] | None = None, + verbose: int | None = None, + **kwargs: Any, +) -> ModelStatistics: + """ + Summarize the given PyTorch model. Summarized information includes: + 1) Layer names, + 2) input/output shapes, + 3) kernel shape, + 4) # of parameters, + 5) # of operations (Mult-Adds), + 6) whether layer is trainable + + NOTE: If neither input_data or input_size are provided, no forward pass through the + network is performed, and the provided model information is limited to layer names. + + Args: + model (nn.Module): + PyTorch model to summarize. The model should be fully in either train() + or eval() mode. If layers are not all in the same mode, running summary + may have side effects on batchnorm or dropout statistics. If you + encounter an issue with this, please open a GitHub issue. + + input_size (Sequence of Sizes): + Shape of input data as a List/Tuple/torch.Size + (dtypes must match model input, default is FloatTensors). + You should include batch size in the tuple. + Default: None + + input_data (Sequence of Tensors): + Arguments for the model's forward pass (dtypes inferred). + If the forward() function takes several parameters, pass in a list of + args or a dict of kwargs (if your forward() function takes in a dict + as its only argument, wrap it in a list). + Default: None + + batch_dim (int): + Batch_dimension of input data. If batch_dim is None, assume + input_data / input_size contains the batch dimension, which is used + in all calculations. Else, expand all tensors to contain the batch_dim. + Specifying batch_dim can be an runtime optimization, since if batch_dim + is specified, torchinfo uses a batch size of 1 for the forward pass. + Default: None + + cache_forward_pass (bool): + If True, cache the run of the forward() function using the model + class name as the key. If the forward pass is an expensive operation, + this can make it easier to modify the formatting of your model + summary, e.g. changing the depth or enabled column types, especially + in Jupyter Notebooks. + WARNING: Modifying the model architecture or input data/input size when + this feature is enabled does not invalidate the cache or re-run the + forward pass, and can cause incorrect summaries as a result. + Default: False + + col_names (Iterable[str]): + Specify which columns to show in the output. Currently supported: ( + "input_size", + "output_size", + "num_params", + "params_percent", + "kernel_size", + "groups", + "mult_adds", + "trainable", + ) + Default: ("output_size", "num_params") + If input_data / input_size are not provided, only "num_params" is used. + + col_width (int): + Width of each column. + Default: 25 + + depth (int): + Depth of nested layers to display (e.g. Sequentials). + Nested layers below this depth will not be displayed in the summary. + Default: 3 + + device (torch.Device): + Uses this torch device for model and input_data. + If not specified, uses the dtype of input_data if given, or the + parameters of the model. Otherwise, uses the result of + torch.cuda.is_available(). + Default: None + + dtypes (List[torch.dtype]): + If you use input_size, torchinfo assumes your input uses FloatTensors. + If your model use a different data type, specify that dtype. + For multiple inputs, specify the size of both inputs, and + also specify the types of each parameter here. + Default: None + + mode (str) + Either "train", "eval" or "same", which determines whether we call + model.train() or model.eval() before calling summary(). In any case, + original model mode is restored at the end. + Default: "same". + + row_settings (Iterable[str]): + Specify which features to show in a row. Currently supported: ( + "ascii_only", + "depth", + "var_names", + ) + Default: ("depth",) + + verbose (int): + 0 (quiet): No output + 1 (default): Print model summary + 2 (verbose): Show weight and bias layers in full detail + Default: 1 + If using a Juypter Notebook or Google Colab, the default is 0. + + **kwargs: + Other arguments used in `model.forward` function. Passing *args is no + longer supported. + + Return: + ModelStatistics object + See torchinfo/model_statistics.py for more information. + """ + input_data_specified = input_data is not None or input_size is not None + if col_names is None: + columns = ( + DEFAULT_COLUMN_NAMES + if input_data_specified + else (ColumnSettings.NUM_PARAMS,) + ) + else: + columns = tuple(ColumnSettings(name) for name in col_names) + + if row_settings is None: + rows = DEFAULT_ROW_SETTINGS + else: + rows = {RowSettings(name) for name in row_settings} + + model_mode = Mode(mode) + + if verbose is None: + verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1 + + if cache_forward_pass is None: + # In the future, this may be enabled by default in Jupyter Notebooks + cache_forward_pass = False + + if device is None: + device = get_device(model, input_data) + elif isinstance(device, str): + device = torch.device(device) + + validate_user_params( + input_data, input_size, columns, col_width, device, dtypes, verbose + ) + + x, correct_input_size = process_input( + input_data, input_size, batch_dim, device, dtypes + ) + summary_list = forward_pass( + model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs + ) + formatting = FormattingOptions(depth, verbose, columns, col_width, rows) + results = ModelStatistics( + summary_list, correct_input_size, get_total_memory_used(x), formatting + ) + if verbose > Verbosity.QUIET: + print(results) + return results + + +def process_input( + input_data: INPUT_DATA_TYPE | None, + input_size: INPUT_SIZE_TYPE | None, + batch_dim: int | None, + device: torch.device | None, + dtypes: list[torch.dtype] | None = None, +) -> tuple[CORRECTED_INPUT_DATA_TYPE, Any]: + """Reads sample input data to get the input size.""" + x = None + correct_input_size = [] + if input_data is not None: + correct_input_size = get_input_data_sizes(input_data) + x = set_device(input_data, device) + if isinstance(x, (torch.Tensor, np.ndarray)): + x = [x] + + if input_size is not None: + assert device is not None + if dtypes is None: + dtypes = [torch.float] * len(input_size) + correct_input_size = get_correct_input_sizes(input_size) + x = get_input_tensor(correct_input_size, batch_dim, dtypes, device) + return x, correct_input_size + + +def forward_pass( + model: nn.Module, + x: CORRECTED_INPUT_DATA_TYPE, + batch_dim: int | None, + cache_forward_pass: bool, + device: torch.device | None, + mode: Mode, + **kwargs: Any, +) -> list[LayerInfo]: + """Perform a forward pass on the model using forward hooks.""" + global _cached_forward_pass + model_name = model.__class__.__name__ + if cache_forward_pass and model_name in _cached_forward_pass: + return _cached_forward_pass[model_name] + + summary_list, _, hooks = apply_hooks(model_name, model, x, batch_dim) + if x is None: + set_children_layers(summary_list) + return summary_list + + kwargs = set_device(kwargs, device) + saved_model_mode = model.training + try: + if mode == Mode.TRAIN: + model.train() + elif mode == Mode.EVAL: + model.eval() + elif mode != Mode.SAME: + raise RuntimeError( + f"Specified model mode ({list(Mode)}) not recognized: {mode}" + ) + + with torch.no_grad(): + model = model if device is None else model.to(device) + if isinstance(x, (list, tuple)): + _ = model(*x, **kwargs) + elif isinstance(x, dict): + _ = model(**x, **kwargs) + else: + # Should not reach this point, since process_input_data ensures + # x is either a list, tuple, or dict + raise ValueError("Unknown input type") + except Exception as e: + executed_layers = [layer for layer in summary_list if layer.executed] + raise RuntimeError( + "Failed to run torchinfo. See above stack traces for more details. " + f"Executed layers up to: {executed_layers}" + ) from e + finally: + if hooks: + for pre_hook, hook in hooks.values(): + pre_hook.remove() + hook.remove() + model.train(saved_model_mode) + + add_missing_container_layers(summary_list) + set_children_layers(summary_list) + + _cached_forward_pass[model_name] = summary_list + return summary_list + + +def set_children_layers(summary_list: list[LayerInfo]) -> None: + """Populates the children and depth_index fields of all LayerInfo.""" + idx: dict[int, int] = {} + for i, layer in enumerate(summary_list): + idx[layer.depth] = idx.get(layer.depth, 0) + 1 + layer.depth_index = idx[layer.depth] + layer.children = get_children_layers(summary_list, i) + + +def add_missing_container_layers(summary_list: list[LayerInfo]) -> None: + """Finds container modules not in the currently listed hierarchy.""" + layer_ids = {layer.layer_id for layer in summary_list} + current_hierarchy: dict[int, LayerInfo] = {} + for idx, layer_info in enumerate(summary_list): + # to keep track index of current layer + # after inserting new layers + rel_idx = 0 + + # create full hierarchy of current layer + hierarchy = {} + parent = layer_info.parent_info + while parent is not None and parent.depth > 0: + hierarchy[parent.depth] = parent + parent = parent.parent_info + + # show hierarchy if it is not there already + for d in range(1, layer_info.depth): + if ( + d not in current_hierarchy + or current_hierarchy[d].module is not hierarchy[d].module + ) and hierarchy[d] is not summary_list[idx + rel_idx - 1]: + hierarchy[d].calculate_num_params() + hierarchy[d].check_recursive(layer_ids) + summary_list.insert(idx + rel_idx, hierarchy[d]) + layer_ids.add(hierarchy[d].layer_id) + + current_hierarchy[d] = hierarchy[d] + rel_idx += 1 + + current_hierarchy[layer_info.depth] = layer_info + + # remove deeper hierarchy + d = layer_info.depth + 1 + while d in current_hierarchy: + current_hierarchy.pop(d) + d += 1 + + +def validate_user_params( + input_data: INPUT_DATA_TYPE | None, + input_size: INPUT_SIZE_TYPE | None, + col_names: tuple[ColumnSettings, ...], + col_width: int, + device: torch.device | None, + dtypes: list[torch.dtype] | None, + verbose: int, +) -> None: + """Raise exceptions if the user's input is invalid.""" + if col_width <= 0: + raise ValueError(f"Column width must be greater than 0: col_width={col_width}") + if verbose not in (0, 1, 2): + raise ValueError( + "Verbose must be either 0 (quiet), 1 (default), or 2 (verbose)." + ) + both_input_specified = input_data is not None and input_size is not None + if both_input_specified: + raise RuntimeError("Only one of (input_data, input_size) should be specified.") + + neither_input_specified = input_data is None and input_size is None + not_allowed = set(col_names) & REQUIRES_INPUT + if neither_input_specified and not_allowed: + raise ValueError( + "You must pass input_data or input_size in order " + f"to use columns: {not_allowed}" + ) + + if dtypes is not None and any( + dtype in (torch.float16, torch.bfloat16) for dtype in dtypes + ): + if input_size is not None: + warnings.warn( + "Half precision is not supported with input_size parameter, and may " + "output incorrect results. Try passing input_data directly.", + stacklevel=2, + ) + if device is not None and device.type == "cpu": + warnings.warn( + "Half precision is not supported on cpu. Set the `device` field or " + "pass `input_data` using the correct device.", + stacklevel=2, + ) + + +def traverse_input_data( + data: Any, action_fn: Callable[..., Any], aggregate_fn: Callable[..., Any] +) -> Any: + """ + Traverses any type of nested input data. On a tensor, returns the action given by + action_fn, and afterwards aggregates the results using aggregate_fn. + """ + if isinstance(data, torch.Tensor): + result = action_fn(data) + elif isinstance(data, np.ndarray): + result = action_fn(torch.from_numpy(data)) + # If the result of action_fn is a torch.Tensor, then action_fn was meant for + # torch.Tensors only (like calling .to(...)) -> Ignore. + if isinstance(result, torch.Tensor): + result = data + + # Recursively apply to collection items + elif isinstance(data, Mapping): + aggregate = aggregate_fn(data) + result = aggregate( + { + k: traverse_input_data(v, action_fn, aggregate_fn) + for k, v in data.items() + } + ) + elif isinstance(data, tuple) and hasattr(data, "_fields"): # Named tuple + aggregate = aggregate_fn(data) + result = aggregate( + *(traverse_input_data(d, action_fn, aggregate_fn) for d in data) + ) + elif isinstance(data, Iterable) and not isinstance(data, str): + aggregate = aggregate_fn(data) + result = aggregate( + [traverse_input_data(d, action_fn, aggregate_fn) for d in data] + ) + else: + # Data is neither a tensor nor a collection + result = data + return result + + +def set_device(data: Any, device: torch.device | None) -> Any: + """Sets device for all input types and collections of input types.""" + return ( + data + if device is None + else traverse_input_data( + data, + action_fn=lambda data: data.to(device, non_blocking=True), + aggregate_fn=type, + ) + ) + + +def get_device( + model: nn.Module, input_data: INPUT_DATA_TYPE | None +) -> torch.device | None: + """ + If input_data is given, the device should not be changed + (to allow for multi-device models, etc.) + + Otherwise gets device of first parameter of model and returns it. + As #371 mentioned, in system with other accelerator (such as MPS), + returns the accelerator (introduced in pytorch 2.6.0) device if available, + otherwise returns cuda if cuda is available, if not then return cpu. + + Attempting to also add DirectML device (may include AMD GPU support). + However since personally didn't have any AMD GPU at the moment, + the implementation is solely based on documentation and examples. + DirectML documentation and example: + https://github.com/microsoft/DirectML + https://medium.com/@ochwada/preparations-for-pytorch-and-directml-on-amd-ryzen-9-6950h-for-ai-projects-15c164d22332 + """ + if input_data is None: + try: + model_parameter = next(model.parameters()) + except StopIteration: + model_parameter = None + + if model_parameter is not None and model_parameter.device: + return model_parameter.device + # Since torch.accelerator is available in torch 2.6.0 and above + if version.parse(torch.__version__) >= version.parse("2.6.0"): + try: + accelerator_module = getattr(torch, "accelerator", None) + if accelerator_module is not None: + if accelerator_module.is_available(): + return accelerator_module.current_accelerator() #type: ignore[no-any-return] + except Exception as e: + print(f"Error getting accelerator device: {e}") + try: + import torch_directml #type: ignore[import-not-found] + print("Pytorch DirectML is installed, using DirectML device.") + if model_parameter is not None: + return torch_directml.device(torch_directml.default_device()) #type: ignore[no-any-return] + except ImportError: + print("Pytorch DirectML is not installed.") + return torch.device("cpu") + return None + + +def get_input_data_sizes(data: Any) -> Any: + """ + Converts input data to an equivalent data structure of torch.Sizes + instead of tensors. + """ + return traverse_input_data( + data, action_fn=lambda data: data.size(), aggregate_fn=type + ) + + +def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int: + """Calculates the total memory of all tensors stored in data.""" + result = traverse_input_data( + data, + action_fn=lambda data: sys.getsizeof( + data.untyped_storage() + if hasattr(data, "untyped_storage") + else data.storage() + ), + aggregate_fn=( + # We don't need the dictionary keys in this case + lambda data: (lambda d: sum(d.values())) + if isinstance(data, Mapping) + else sum + ), + ) + return cast(int, result) + + +def get_input_tensor( + input_size: CORRECTED_INPUT_SIZE_TYPE, + batch_dim: int | None, + dtypes: list[torch.dtype], + device: torch.device, +) -> list[torch.Tensor]: + """Get input_tensor with batch size 1 for use in model.forward()""" + x = [] + for size, dtype in zip(input_size, dtypes): + input_tensor = torch.rand(*size) + if batch_dim is not None: + input_tensor = input_tensor.unsqueeze(dim=batch_dim) + x.append(input_tensor.to(device).type(dtype)) + return x + + +def flatten(nested_array: INPUT_SIZE_TYPE) -> Iterator[Any]: + """Flattens a nested array.""" + for item in nested_array: + if isinstance(item, (list, tuple)): + yield from flatten(item) + else: + yield item + + +def get_correct_input_sizes(input_size: INPUT_SIZE_TYPE) -> CORRECTED_INPUT_SIZE_TYPE: + """ + Convert input_size to the correct form, which is a list of tuples. + Also handles multiple inputs to the network. + """ + if not isinstance(input_size, (list, tuple)): + raise TypeError( + "Input_size is not a recognized type. Please ensure input_size is valid.\n" + "For multiple inputs to the network, ensure input_size is a list of tuple " + "sizes. If you are having trouble here, please submit a GitHub issue." + ) + if not input_size or any(size <= 0 for size in flatten(input_size)): + raise ValueError("Input_data is invalid, or negative size found in input_data.") + + if isinstance(input_size, list) and isinstance(input_size[0], int): + return [tuple(input_size)] + if isinstance(input_size, list): + return input_size + if isinstance(input_size, tuple) and isinstance(input_size[0], tuple): + return list(input_size) + return [input_size] + + +def construct_pre_hook( + global_layer_info: dict[int, LayerInfo], + summary_list: list[LayerInfo], + layer_ids: set[int], + var_name: str, + curr_depth: int, + parent_info: LayerInfo | None, +) -> Callable[[nn.Module, Any], None]: + def pre_hook(module: nn.Module, inputs: Any) -> None: + """Create a LayerInfo object to aggregate layer information.""" + del inputs + info = LayerInfo(var_name, module, curr_depth, parent_info) + info.calculate_num_params() + info.check_recursive(layer_ids) + summary_list.append(info) + layer_ids.add(info.layer_id) + global_layer_info[info.layer_id] = info + + return pre_hook + + +def construct_hook( + global_layer_info: dict[int, LayerInfo], batch_dim: int | None +) -> Callable[[nn.Module, Any, Any], None]: + def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: + """Update LayerInfo after forward pass.""" + info = global_layer_info[id(module)] + if info.contains_lazy_param: + info.calculate_num_params() + info.input_size, _ = info.calculate_size(inputs, batch_dim) + info.output_size, elem_bytes = info.calculate_size(outputs, batch_dim) + info.output_bytes = elem_bytes * prod(info.output_size) + info.executed = True + info.calculate_macs() + + return hook + + +def apply_hooks( + model_name: str, + module: nn.Module, + input_data: CORRECTED_INPUT_DATA_TYPE, + batch_dim: int | None, +) -> tuple[ + list[LayerInfo], + dict[int, LayerInfo], + dict[int, tuple[RemovableHandle, RemovableHandle]], +]: + """ + If input_data is provided, recursively adds hooks to all layers of the model. + Else, fills summary_list with layer info without computing a + forward pass through the network. + """ + summary_list: list[LayerInfo] = [] + layer_ids: set[int] = set() # Used to optimize is_recursive() + global_layer_info: dict[int, LayerInfo] = {} + hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] = {} + stack: list[tuple[str, nn.Module, int, LayerInfo | None]] = [ + (model_name, module, 0, None) + ] + while stack: + var_name, module, curr_depth, parent_info = stack.pop() + module_id = id(module) + + # Fallback is used if the layer's pre-hook is never called, for example in + # ModuleLists or Sequentials. + global_layer_info[module_id] = LayerInfo( + var_name, module, curr_depth, parent_info + ) + pre_hook = construct_pre_hook( + global_layer_info, + summary_list, + layer_ids, + var_name, + curr_depth, + parent_info, + ) + if input_data is None or isinstance(module, WRAPPER_MODULES): + pre_hook(module, None) + else: + # Register the hook using the last layer that uses this module. + if module_id in hooks: + for hook in hooks[module_id]: + hook.remove() + hooks[module_id] = ( + module.register_forward_pre_hook(pre_hook), + module.register_forward_hook( + construct_hook(global_layer_info, batch_dim) + ), + ) + + # Replaces the equivalent recursive call by appending all of the + # subsequent the module children stack calls in the encountered order. + # Note: module.named_modules(remove_duplicate=False) doesn't work for + # some unknown reason (infinite recursion) + stack += [ + (name, mod, curr_depth + 1, global_layer_info[module_id]) + for name, mod in reversed(module._modules.items()) + if mod is not None + ] + return summary_list, global_layer_info, hooks + + +def clear_cached_forward_pass() -> None: + """Clear the forward pass cache.""" + global _cached_forward_pass + _cached_forward_pass = {} From 6332edb491ca4cce09577347d486a044eda3d8bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Sep 2025 14:12:07 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchinfo/torchinfo.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index b9f2ff74..e54ba68c 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -16,13 +16,11 @@ ) import numpy as np - import torch +from packaging import version from torch import nn from torch.jit import ScriptModule -from torch.return_types import mode from torch.utils.hooks import RemovableHandle -from packaging import version from .enums import ColumnSettings, Mode, RowSettings, Verbosity from .formatting import FormattingOptions @@ -502,14 +500,15 @@ def get_device( accelerator_module = getattr(torch, "accelerator", None) if accelerator_module is not None: if accelerator_module.is_available(): - return accelerator_module.current_accelerator() #type: ignore[no-any-return] + return accelerator_module.current_accelerator() # type: ignore[no-any-return] except Exception as e: print(f"Error getting accelerator device: {e}") try: - import torch_directml #type: ignore[import-not-found] + import torch_directml # type: ignore[import-not-found] + print("Pytorch DirectML is installed, using DirectML device.") if model_parameter is not None: - return torch_directml.device(torch_directml.default_device()) #type: ignore[no-any-return] + return torch_directml.device(torch_directml.default_device()) # type: ignore[no-any-return] except ImportError: print("Pytorch DirectML is not installed.") return torch.device("cpu")