diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index c0b6bb5b..e54ba68c 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -17,6 +17,7 @@ import numpy as np import torch +from packaging import version from torch import nn from torch.jit import ScriptModule from torch.utils.hooks import RemovableHandle @@ -473,8 +474,17 @@ def get_device( If input_data is given, the device should not be changed (to allow for multi-device models, etc.) - Otherwise gets device of first parameter of model and returns it if it is on cuda, - otherwise returns cuda if available or cpu if not. + Otherwise gets device of first parameter of model and returns it. + As mentioned in #371, on systems with another accelerator (such as MPS), + returns the accelerator device (introduced in PyTorch 2.6.0) if available, + otherwise returns cuda if cuda is available, or cpu if not. + + Also attempts to add DirectML device support (which may include AMD GPUs). + However, since the author had no AMD GPU available at the time, + the implementation is based solely on documentation and examples.
+ DirectML documentation and example: + https://github.com/microsoft/DirectML + https://medium.com/@ochwada/preparations-for-pytorch-and-directml-on-amd-ryzen-9-6950h-for-ai-projects-15c164d22332 """ if input_data is None: try: @@ -482,9 +492,26 @@ def get_device( except StopIteration: model_parameter = None - if model_parameter is not None and model_parameter.is_cuda: + if model_parameter is not None and model_parameter.device: return model_parameter.device - return torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Since torch.accelerator is available in torch 2.6.0 and above + if version.parse(torch.__version__) >= version.parse("2.6.0"): + try: + accelerator_module = getattr(torch, "accelerator", None) + if accelerator_module is not None: + if accelerator_module.is_available(): + return accelerator_module.current_accelerator() # type: ignore[no-any-return] + except Exception as e: + print(f"Error getting accelerator device: {e}") + try: + import torch_directml # type: ignore[import-not-found] + + print("Pytorch DirectML is installed, using DirectML device.") + if model_parameter is not None: + return torch_directml.device(torch_directml.default_device()) # type: ignore[no-any-return] + except ImportError: + print("Pytorch DirectML is not installed.") + return torch.device("cpu") return None