From 66a48d87a0ec756858a705727eb93b2fcb3b72ea Mon Sep 17 00:00:00 2001 From: austin1997 Date: Tue, 23 Jun 2026 05:49:25 +0000 Subject: [PATCH] [ROCm] Avoid unsupported BF16 llm_int8 registration --- .../phi/kernels/gpu/llm_int8_linear_kernel.cu | 9 +++- test/quantization/test_llm_int8_linear.py | 54 +++++++++---------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu b/paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu index 27e51c678274f..056585ff653a2 100644 --- a/paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu @@ -81,5 +81,10 @@ PD_REGISTER_KERNEL(llm_int8_linear, GPU, ALL_LAYOUT, phi::LLMInt8LinearKernel, - phi::float16, - phi::bfloat16) {} + phi::float16 +#if defined(PADDLE_WITH_CUDA) + , + phi::bfloat16 +#endif +) { +} diff --git a/test/quantization/test_llm_int8_linear.py b/test/quantization/test_llm_int8_linear.py index b1fdb0e3eba5b..c9bb146e4b435 100644 --- a/test/quantization/test_llm_int8_linear.py +++ b/test/quantization/test_llm_int8_linear.py @@ -24,10 +24,17 @@ from paddle.framework import set_default_dtype +def llm_int8_linear_not_supported(): + return ( + not core.is_compiled_with_cuda() + or paddle.is_compiled_with_rocm() + or paddle.device.cuda.get_device_capability()[0] < 8 + ) + + @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase(unittest.TestCase): def config(self): @@ -194,9 +201,8 @@ def test_llm_int8_linear(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase1(LLMInt8LinearTestCase): def config(self): @@ -206,9 +212,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase2(LLMInt8LinearTestCase): def config(self): @@ -219,10 +224,9 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8 + llm_int8_linear_not_supported() or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "quantized_matmul requires CUDA_ARCH >= 8 or core is not support bfloat16", + "llm_int8_linear requires CUDA, CUDA_ARCH >= 8 and bfloat16 support", ) class LLMInt8LinearTestCase4(LLMInt8LinearTestCase): def config(self): @@ -232,9 +236,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase5(LLMInt8LinearTestCase): def config(self): @@ -245,9 +248,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase7(LLMInt8LinearTestCase): def config(self): @@ -259,9 +261,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase8(LLMInt8LinearTestCase): def config(self): @@ -274,9 +275,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCase10(LLMInt8LinearTestCase): def config(self): @@ -289,10 +289,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_compiled_with_cuda() - or paddle.device.cuda.get_device_capability()[0] < 8, - "quantized_matmul requires CUDA_ARCH >= 8", + llm_int8_linear_not_supported(), + "llm_int8_linear requires CUDA and CUDA_ARCH >= 8", ) class LLMInt8LinearTestCaseStatic(LLMInt8LinearTestCase): def config(self):