17 changes: 17 additions & 0 deletions src/torchada/_patch.py
@@ -1176,6 +1176,23 @@ def _patch_autotune_process():
autotune_process.CUDA_VISIBLE_DEVICES = "MUSA_VISIBLE_DEVICES"


@patch_function
@requires_import("torch_musa")
def _patch_validate_device():
    import torch.nn.attention.flex_attention
    from torch import Tensor

    def _validate_device(query: Tensor, key: Tensor, value: Tensor):
        """TODO: Remove once support for non-CUDA/CPU devices is added upstream.

        We only need to check query, since it has already been verified that
        q, k, and v are on the same device.
        """
        if query.device.type not in ("cuda", "cpu", "musa"):
            raise ValueError(
                "FlexAttention is only supported on CUDA, CPU, or MUSA devices. "
                f"Found input tensors on {query.device.type} device."
            )

    # Replace torch's validator so MUSA tensors pass FlexAttention's device check.
    torch.nn.attention.flex_attention._validate_device = _validate_device


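A minimal usage sketch of what this patch enables (hypothetical, not from the PR; it assumes importing torchada applies the patches registered via @patch_function, and that torch_musa is installed so the "musa" device exists):

import torch
import torchada  # assumed to apply the registered patches on import
from torch.nn.attention.flex_attention import flex_attention

# (B, H, S, E) tensors on a MUSA device; before the patch, _validate_device
# raised ValueError for any device type other than "cuda" or "cpu".
q = torch.randn(2, 8, 128, 64, device="musa")
k = torch.randn(2, 8, 128, 64, device="musa")
v = torch.randn(2, 8, 128, 64, device="musa")

out = flex_attention(q, k, v)  # device validation now accepts "musa"
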
class _CDLLWrapper:
"""
Wrapper for ctypes.CDLL that automatically translates CUDA/NCCL function names
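
The body of _CDLLWrapper is collapsed in this diff, so the following is only an illustrative sketch of the name-translation idea its docstring describes, not the PR's implementation; the class name, prefix map, and fallback logic below are all assumptions:

import ctypes

class _NameTranslatingCDLL:  # hypothetical name, not from the PR
    # Assumed mapping of CUDA/NCCL symbol prefixes to their MUSA/MCCL twins.
    _PREFIX_MAP = {"cuda": "musa", "nccl": "mccl"}

    def __init__(self, path: str):
        self._lib = ctypes.CDLL(path)

    def __getattr__(self, name: str):
        # Try the requested symbol first; on a miss, retry with the
        # CUDA/NCCL prefix rewritten to its MUSA/MCCL equivalent.
        try:
            return getattr(self._lib, name)
        except AttributeError:
            for old, new in self._PREFIX_MAP.items():
                if name.startswith(old):
                    return getattr(self._lib, new + name[len(old):])
            raise

With a wrapper like this, code that looks up CUDA symbols by name on a CDLL handle can resolve the corresponding MUSA symbols without modification.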