From a93e8866cd77b0e6bfc3a0a78cc1ac0683b8c329 Mon Sep 17 00:00:00 2001
From: wenqf11
Date: Fri, 13 Feb 2026 14:54:51 +0800
Subject: [PATCH] patch validate_device

---
 src/torchada/_patch.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py
index bc689a5..0a7c1cc 100644
--- a/src/torchada/_patch.py
+++ b/src/torchada/_patch.py
@@ -1176,6 +1176,23 @@ def _patch_autotune_process():
     autotune_process.CUDA_VISIBLE_DEVICES = "MUSA_VISIBLE_DEVICES"
 
 
+@patch_function
+@requires_import("torch_musa")
+def _patch_validate_device():
+    import torch.nn.attention.flex_attention
+    from torch import Tensor
+    def _validate_device(query: Tensor, key: Tensor, value: Tensor):
+        """TODO: Remove once non cuda/cpu devices support is added
+        We only need to check query since we have already checked that q, k, v are on the same device
+        """
+        if query.device.type != "cuda" and query.device.type != "cpu" and query.device.type != "musa":
+            raise ValueError(
+                "FlexAttention is only supported on CUDA, CPU, or MUSA devices. "
+                f"Found input tensors on {query.device.type} device."
+            )
+    torch.nn.attention.flex_attention._validate_device = _validate_device
+
+
 class _CDLLWrapper:
     """
     Wrapper for ctypes.CDLL that automatically translates CUDA/NCCL function names
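
Note (not part of the diff above): a minimal smoke test sketch for the patched validator. It assumes torch_musa is installed with a MUSA device available, that tensors on it report device type "musa", and that importing torchada applies its registered patch functions; none of these are confirmed by the patch itself.

# Hypothetical smoke test; assumes importing torchada runs _patch_validate_device
# and that torch_musa registers the "musa" device type.
import torch
import torch_musa   # noqa: F401  (registers the MUSA backend)
import torchada     # noqa: F401  (assumed to apply the monkey patches on import)
import torch.nn.attention.flex_attention as flex_attention

q = torch.randn(2, 4, 16, 8, device="musa")
k = torch.randn(2, 4, 16, 8, device="musa")
v = torch.randn(2, 4, 16, 8, device="musa")

# Before this patch the call below raised ValueError for MUSA tensors;
# with the patched _validate_device it should return without raising.
flex_attention._validate_device(q, k, v)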