Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.jit import ScriptModule
from torch.utils.hooks import RemovableHandle
Expand Down Expand Up @@ -473,18 +474,44 @@ def get_device(
If input_data is given, the device should not be changed
(to allow for multi-device models, etc.)

Otherwise gets device of first parameter of model and returns it if it is on cuda,
otherwise returns cuda if available or cpu if not.
Otherwise gets device of first parameter of model and returns it.
As #371 mentioned, in system with other accelerator (such as MPS),
returns the accelerator (introduced in pytorch 2.6.0) device if available,
otherwise returns cuda if cuda is available, if not then return cpu.

Attempting to also add DirectML device (may include AMD GPU support).
However since personally didn't have any AMD GPU at the moment,
the implementation is solely based on documentation and examples.
DirectML documentation and example:
https://github.com/microsoft/DirectML
https://medium.com/@ochwada/preparations-for-pytorch-and-directml-on-amd-ryzen-9-6950h-for-ai-projects-15c164d22332
"""
if input_data is None:
try:
model_parameter = next(model.parameters())
except StopIteration:
model_parameter = None

if model_parameter is not None and model_parameter.is_cuda:
if model_parameter is not None and model_parameter.device:
return model_parameter.device
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Since torch.accelerator is available in torch 2.6.0 and above
if version.parse(torch.__version__) >= version.parse("2.6.0"):
try:
accelerator_module = getattr(torch, "accelerator", None)
if accelerator_module is not None:
if accelerator_module.is_available():
return accelerator_module.current_accelerator() # type: ignore[no-any-return]
except Exception as e:
print(f"Error getting accelerator device: {e}")
try:
import torch_directml # type: ignore[import-not-found]

print("Pytorch DirectML is installed, using DirectML device.")
if model_parameter is not None:
return torch_directml.device(torch_directml.default_device()) # type: ignore[no-any-return]
except ImportError:
print("Pytorch DirectML is not installed.")
return torch.device("cpu")
return None


Expand Down
Loading