Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,10 @@ PD_REGISTER_KERNEL(llm_int8_linear,
GPU,
ALL_LAYOUT,
phi::LLMInt8LinearKernel,
phi::float16,
phi::bfloat16) {}
phi::float16
#if defined(PADDLE_WITH_CUDA)
,
phi::bfloat16
#endif
) {
}
54 changes: 26 additions & 28 deletions test/quantization/test_llm_int8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@
from paddle.framework import set_default_dtype


def llm_int8_linear_not_supported():
return (
not core.is_compiled_with_cuda()
or paddle.is_compiled_with_rocm()
or paddle.device.cuda.get_device_capability()[0] < 8
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase(unittest.TestCase):
def config(self):
Expand Down Expand Up @@ -194,9 +201,8 @@ def test_llm_int8_linear(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase1(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -206,9 +212,8 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase2(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -219,10 +224,9 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8
llm_int8_linear_not_supported()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"quantized_matmul requires CUDA_ARCH >= 8 or core is not support bfloat16",
"llm_int8_linear requires CUDA, CUDA_ARCH >= 8 and bfloat16 support",
)
class LLMInt8LinearTestCase4(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -232,9 +236,8 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase5(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -245,9 +248,8 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase7(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -259,9 +261,8 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase8(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -274,9 +275,8 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCase10(LLMInt8LinearTestCase):
def config(self):
Expand All @@ -289,10 +289,8 @@ def config(self):


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA_ARCH >= 8",
llm_int8_linear_not_supported(),
"llm_int8_linear requires CUDA and CUDA_ARCH >= 8",
)
class LLMInt8LinearTestCaseStatic(LLMInt8LinearTestCase):
def config(self):
Expand Down