diff --git a/src/ntops/torch/utils.py b/src/ntops/torch/utils.py
index 1e0d5ba..e9b2dde 100644
--- a/src/ntops/torch/utils.py
+++ b/src/ntops/torch/utils.py
@@ -6,10 +6,55 @@
 import ntops
 
 
+class _CachedMakeDefaultConfig:
+    def __init__(self, num_warps=None, num_stages=None, max_num_configs=None):
+        self.num_warps = num_warps
+
+        self.num_stages = num_stages
+
+        self.max_num_configs = max_num_configs
+
+
+_cached_make_default_config = _CachedMakeDefaultConfig()
+
+
+def get_default_num_warps():
+    return _cached_make_default_config.num_warps
+
+
+def set_default_num_warps(num_warps):
+    _cached_make_default_config.num_warps = num_warps
+
+
+def get_default_num_stages():
+    return _cached_make_default_config.num_stages
+
+
+def set_default_num_stages(num_stages):
+    _cached_make_default_config.num_stages = num_stages
+
+
+def get_default_max_num_configs():
+    return _cached_make_default_config.max_num_configs
+
+
+def set_default_max_num_configs(max_num_configs):
+    _cached_make_default_config.max_num_configs = max_num_configs
+
+
 @functools.cache
 def _cached_make(
     premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
 ):
+    if num_warps is None:
+        num_warps = _cached_make_default_config.num_warps
+
+    if num_stages is None:
+        num_stages = _cached_make_default_config.num_stages
+
+    if max_num_configs is None:
+        max_num_configs = _cached_make_default_config.max_num_configs
+
     return ninetoothed.make(
         *premake(*args, **keywords),
         num_warps=num_warps,
diff --git a/tests/conftest.py b/tests/conftest.py
index 3f29fd4..0f89217 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,11 +4,15 @@
 import pytest
 import torch
 
+import ntops.torch.utils
+
 
 def pytest_configure():
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
     torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
 
+    ntops.torch.utils.set_default_max_num_configs(_DEFAULT_MAX_NUM_CONFIGS)
+
 
 def pytest_collectstart(collector):
     if isinstance(collector, pytest.Module):
@@ -25,6 +29,9 @@ def set_seed_per_test(request):
     _set_random_seed(_hash(_test_case_path_from_request(request)))
 
 
+_DEFAULT_MAX_NUM_CONFIGS = 3
+
+
 def _set_random_seed(seed):
     random.seed(seed)
     torch.manual_seed(seed)
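The `ntops.torch.utils` hunk above adds module-level getters/setters that feed `_cached_make` whenever a caller leaves `num_warps`, `num_stages`, or `max_num_configs` as `None`, and `tests/conftest.py` uses `set_default_max_num_configs` to cap the number of candidate configurations for the whole test session. A minimal usage sketch (the numeric values below are illustrative, not part of this diff):

```python
import ntops.torch.utils

# Set process-wide fallbacks once, ideally before any kernel is built.
# Explicit keyword arguments passed to _cached_make still win, because the
# fallback is only consulted when the corresponding value is None.
ntops.torch.utils.set_default_num_warps(4)
ntops.torch.utils.set_default_num_stages(3)
ntops.torch.utils.set_default_max_num_configs(3)

assert ntops.torch.utils.get_default_num_warps() == 4
assert ntops.torch.utils.get_default_num_stages() == 3
assert ntops.torch.utils.get_default_max_num_configs() == 3
```

Because `_cached_make` is memoized with `functools.cache`, the defaults are read when a given kernel is first built and a later change does not affect entries already in the cache, which is why `conftest.py` sets them in `pytest_configure` before any test runs.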
diff --git a/tests/test_abs.py b/tests/test_abs.py
index 3ebe40c..d643cf8 100644
--- a/tests/test_abs.py
+++ b/tests/test_abs.py
@@ -8,12 +8,10 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_abs(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.abs(input)
     reference_output = torch.abs(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_add.py b/tests/test_add.py
index 7e67ed3..1a34e97 100644
--- a/tests/test_add.py
+++ b/tests/test_add.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_add(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
     alpha = gauss()
@@ -18,4 +16,4 @@ def test_cuda(shape, dtype, atol, rtol):
     ninetoothed_output = ntops.torch.add(input, other, alpha=alpha)
     reference_output = torch.add(input, other, alpha=alpha)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_addmm.py b/tests/test_addmm.py
index 8c105c2..56361d6 100644
--- a/tests/test_addmm.py
+++ b/tests/test_addmm.py
@@ -9,9 +9,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(m, n, k, dtype, atol, rtol):
-    device = "cuda"
-
+def test_addmm(m, n, k, dtype, device, rtol, atol):
     input = torch.randn((m, n), dtype=dtype, device=device)
     x = torch.randn((m, k), dtype=dtype, device=device)
     y = torch.randn((k, n), dtype=dtype, device=device)
@@ -21,4 +19,4 @@ def test_cuda(m, n, k, dtype, atol, rtol):
     ninetoothed_output = ntops.torch.addmm(input, x, y, beta=beta, alpha=alpha)
     reference_output = torch.addmm(input, x, y, beta=beta, alpha=alpha)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_bitwise_and.py b/tests/test_bitwise_and.py
index b286af2..77ddf95 100644
--- a/tests/test_bitwise_and.py
+++ b/tests/test_bitwise_and.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments(False))
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_bitwise_and(shape, dtype, device, rtol, atol):
     if dtype == torch.bool:
         prob = 0.5
         input = torch.rand(shape, dtype=torch.float32, device=device) > prob
diff --git a/tests/test_bitwise_not.py b/tests/test_bitwise_not.py
index 4a70a96..93e48a5 100644
--- a/tests/test_bitwise_not.py
+++ b/tests/test_bitwise_not.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments(False))
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_bitwise_not(shape, dtype, device, rtol, atol):
     if dtype == torch.bool:
         prob = 0.5
         input = torch.rand(shape, dtype=torch.float32, device=device) > prob
diff --git a/tests/test_bitwise_or.py b/tests/test_bitwise_or.py
index 7f5a4d6..2d8096e 100644
--- a/tests/test_bitwise_or.py
+++ b/tests/test_bitwise_or.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments(False))
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_bitwise_or(shape, dtype, device, rtol, atol):
     if dtype == torch.bool:
         prob = 0.5
         input = torch.rand(shape, dtype=torch.float32, device=device) > prob
diff --git a/tests/test_bmm.py b/tests/test_bmm.py
index 96337a8..2b66d9f 100644
--- a/tests/test_bmm.py
+++ b/tests/test_bmm.py
@@ -10,9 +10,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(m, n, k, dtype, atol, rtol):
-    device = "cuda"
-
+def test_bmm(m, n, k, dtype, device, rtol, atol):
     b = random.randint(4, 16)
     input = torch.randn((b, m, k), dtype=dtype, device=device)
     other = torch.randn((b, k, n), dtype=dtype, device=device)
@@ -20,4 +18,4 @@
     ninetoothed_output = ntops.torch.bmm(input, other)
     reference_output = torch.bmm(input, other)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_clamp.py b/tests/test_clamp.py
index 608fd3d..4ba0de9 100644
--- a/tests/test_clamp.py
+++ b/tests/test_clamp.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_clamp(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     min = torch.randn(shape, dtype=dtype, device=device)
     max = torch.randn(shape, dtype=dtype, device=device)
@@ -18,4 +16,4 @@
     ninetoothed_output = ntops.torch.clamp(input, min, max)
     reference_output = torch.clamp(input, min, max)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_cos.py b/tests/test_cos.py
index 005588c..d663310 100644
--- a/tests/test_cos.py
+++ b/tests/test_cos.py
@@ -8,16 +8,13 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
+def test_cos(shape, dtype, device, rtol, atol):
     # TODO: Test for `float16` later.
     if dtype is torch.float16:
         return
-
-    device = "cuda"
-
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.cos(input)
     reference_output = torch.cos(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_div.py b/tests/test_div.py
index bf3644f..8a3878d 100644
--- a/tests/test_div.py
+++ b/tests/test_div.py
@@ -7,21 +7,22 @@
 @skip_if_cuda_not_available
+@pytest.mark.parametrize(
+    "rounding_mode",
+    [
+        None,
+        pytest.param(
+            "trunc", marks=pytest.mark.skip(reason="TODO: Test for `trunc` mode later.")
+        ),
+        "floor",
+    ],
+)
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_div(shape, rounding_mode, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
 
-    for rounding_mode in (None, "trunc", "floor"):
-        # TODO: Test for `trunc` mode later.
-        if rounding_mode == "trunc":
-            continue
-
-        ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode)
-        reference_output = torch.div(input, other, rounding_mode=rounding_mode)
+    ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode)
+    reference_output = torch.div(input, other, rounding_mode=rounding_mode)
 
-        assert torch.allclose(
-            ninetoothed_output, reference_output, atol=atol, rtol=rtol
-        )
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_dropout.py b/tests/test_dropout.py
index 1181e79..89c08fd 100644
--- a/tests/test_dropout.py
+++ b/tests/test_dropout.py
@@ -11,9 +11,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_dropout(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     p = random.uniform(0, 1)
diff --git a/tests/test_eq.py b/tests/test_eq.py
index 6c44191..04ebe42 100644
--- a/tests/test_eq.py
+++ b/tests/test_eq.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_eq(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_exp.py b/tests/test_exp.py
index f0d66d3..33f7c59 100644
--- a/tests/test_exp.py
+++ b/tests/test_exp.py
@@ -8,16 +8,13 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
+def test_exp(shape, dtype, device, rtol, atol):
     # TODO: Test for `float16` later.
     if dtype is torch.float16:
         return
-
-    device = "cuda"
-
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.exp(input)
     reference_output = torch.exp(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_ge.py b/tests/test_ge.py
index 0bf2a49..bc4a83b 100644
--- a/tests/test_ge.py
+++ b/tests/test_ge.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_ge(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_gelu.py b/tests/test_gelu.py
index 6395a2e..af14670 100644
--- a/tests/test_gelu.py
+++ b/tests/test_gelu.py
@@ -8,16 +8,20 @@
 @skip_if_cuda_not_available
+@pytest.mark.parametrize(
+    "approximate",
+    (
+        "none",
+        pytest.param(
+            "tanh", marks=pytest.mark.skip(reason="TODO: Test for `tanh` mode later.")
+        ),
+    ),
+)
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_gelu(shape, approximate, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
 
-    for approximate in ("none", "tanh"):
-        ninetoothed_output = ntops.torch.gelu(input)
-        reference_output = F.gelu(input)
+    ninetoothed_output = ntops.torch.gelu(input, approximate=approximate)
+    reference_output = F.gelu(input, approximate=approximate)
 
-        assert torch.allclose(
-            ninetoothed_output, reference_output, atol=atol, rtol=rtol
-        )
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
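The `test_div.py` and `test_gelu.py` hunks above replace in-test `for` loops over modes with an extra `@pytest.mark.parametrize` axis, using `pytest.param(..., marks=pytest.mark.skip(...))` for the modes that are still TODO. A self-contained sketch of the pattern, with a hypothetical `mode` axis and a trivial body standing in for the real operator calls:

```python
import pytest


@pytest.mark.parametrize(
    "mode",
    [
        None,
        pytest.param("trunc", marks=pytest.mark.skip(reason="TODO: cover later.")),
        "floor",
    ],
)
def test_mode_axis(mode):
    # Each mode is collected as its own test case, so the skipped one shows up
    # as SKIPPED in the report instead of being silently bypassed by `continue`,
    # and a failure is attributed to the specific mode that produced it.
    assert mode in (None, "floor")
```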
diff --git a/tests/test_gt.py b/tests/test_gt.py
index 509b6e3..cfc03cc 100644
--- a/tests/test_gt.py
+++ b/tests/test_gt.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_gt(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_isinf.py b/tests/test_isinf.py
index b4f8d80..e651836 100644
--- a/tests/test_isinf.py
+++ b/tests/test_isinf.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_isinf(shape, dtype, device, rtol, atol):
     def generate_inf_tensor(shape, dtype, device):
         x = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_isnan.py b/tests/test_isnan.py
index 7f87bcc..7e29b01 100644
--- a/tests/test_isnan.py
+++ b/tests/test_isnan.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_isnan(shape, dtype, device, rtol, atol):
     def generate_nan_tensor(shape, dtype, device):
         nan_prob = 0.4
         prob_tensor = torch.rand(shape, device=device)
diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py
index fbe2c73..3d1bd83 100644
--- a/tests/test_layer_norm.py
+++ b/tests/test_layer_norm.py
@@ -13,9 +13,9 @@
 @pytest.mark.parametrize("bias_is_none", (False, True))
 @pytest.mark.parametrize("weight_is_none", (False, True))
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps):
-    device = "cuda"
-
+def test_layer_norm(
+    shape, dtype, device, rtol, atol, weight_is_none, bias_is_none, eps
+):
     input = torch.randn(shape, dtype=dtype, device=device)
     normalized_shape = shape[-random.randint(1, len(shape)) :]
     if weight_is_none:
@@ -34,4 +34,4 @@ def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps):
         input, normalized_shape, weight=weight, bias=bias, eps=eps
     )
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_le.py b/tests/test_le.py
index b6d3b28..5a6f961 100644
--- a/tests/test_le.py
+++ b/tests/test_le.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_le(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_lt.py b/tests/test_lt.py
index 1cda1f5..7d2b376 100644
--- a/tests/test_lt.py
+++ b/tests/test_lt.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_lt(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_mm.py b/tests/test_mm.py
index 8874f32..57bd387 100644
--- a/tests/test_mm.py
+++ b/tests/test_mm.py
@@ -11,6 +11,8 @@ def generate_arguments():
     arguments = []
 
     for dtype in (torch.float32, torch.float16):
+        device = "cuda"
+
         if dtype is torch.float32:
             atol = 0.001
             rtol = 0.001
@@ -25,20 +27,18 @@ def generate_random_size():
         n = generate_random_size()
         k = generate_random_size()
 
-        arguments.append((m, n, k, dtype, atol, rtol))
+        arguments.append((m, n, k, dtype, device, rtol, atol))
 
-    return "m, n, k, dtype, atol, rtol", arguments
+    return "m, n, k, dtype, device, rtol, atol", arguments
 
 
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(m, n, k, dtype, atol, rtol):
-    device = "cuda"
-
+def test_mm(m, n, k, dtype, device, rtol, atol):
     input = torch.randn((m, k), dtype=dtype, device=device)
     other = torch.randn((k, n), dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.mm(input, other)
     reference_output = torch.mm(input, other)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_mul.py b/tests/test_mul.py
index 78d7951..d707cf9 100644
--- a/tests/test_mul.py
+++ b/tests/test_mul.py
@@ -8,13 +8,11 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_mul(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.mul(input, other)
     reference_output = torch.mul(input, other)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_ne.py b/tests/test_ne.py
index d1ca7ec..e6568c5 100644
--- a/tests/test_ne.py
+++ b/tests/test_ne.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_ne(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
diff --git a/tests/test_neg.py b/tests/test_neg.py
index 243a6b1..4f77f20 100644
--- a/tests/test_neg.py
+++ b/tests/test_neg.py
@@ -8,12 +8,10 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_neg(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.neg(input)
     reference_output = torch.neg(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_pow.py b/tests/test_pow.py
index 5ffc1a0..a1df1b0 100644
--- a/tests/test_pow.py
+++ b/tests/test_pow.py
@@ -8,13 +8,11 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
+def test_pow(shape, dtype, device, rtol, atol):
     # TODO: Test for `float16` later.
     if dtype is torch.float16:
         return
 
-    device = "cuda"
-
     input = torch.randn(shape, dtype=dtype, device=device)
     exponent = torch.randn(shape, dtype=dtype, device=device)
@@ -22,5 +20,5 @@
     reference_output = torch.pow(input, exponent)
 
     assert torch.allclose(
-        ninetoothed_output, reference_output, atol=atol, rtol=rtol, equal_nan=True
+        ninetoothed_output, reference_output, rtol=rtol, atol=atol, equal_nan=True
     )
diff --git a/tests/test_relu.py b/tests/test_relu.py
index 9fdade5..28cc07c 100644
--- a/tests/test_relu.py
+++ b/tests/test_relu.py
@@ -8,16 +8,12 @@
 @skip_if_cuda_not_available
+@pytest.mark.parametrize("inplace", (False, True))
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_relu(shape, inplace, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
 
-    for inplace in (False, True):
-        ninetoothed_output = ntops.torch.relu(input, inplace)
-        reference_output = F.relu(input, inplace)
+    ninetoothed_output = ntops.torch.relu(input, inplace)
+    reference_output = F.relu(input, inplace)
 
-        assert torch.allclose(
-            ninetoothed_output, reference_output, atol=atol, rtol=rtol
-        )
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py
index 646906a..211589d 100644
--- a/tests/test_rms_norm.py
+++ b/tests/test_rms_norm.py
@@ -12,9 +12,7 @@
 @pytest.mark.parametrize("eps", (None, 0, 1e-5, 1e-3))
 @pytest.mark.parametrize("weight_is_none", (False, True))
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol, weight_is_none, eps):
-    device = "cuda"
-
+def test_rms_norm(shape, dtype, device, rtol, atol, weight_is_none, eps):
     input = torch.randn(shape, dtype=dtype, device=device)
     normalized_shape = shape[-random.randint(1, len(shape)) :]
     if weight_is_none:
@@ -29,4 +27,4 @@ def test_cuda(shape, dtype, atol, rtol, weight_is_none, eps):
         input, normalized_shape, weight=weight, eps=eps
     )
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_rotary_position_embedding.py b/tests/test_rotary_position_embedding.py
index a017c9d..bb2749c 100644
--- a/tests/test_rotary_position_embedding.py
+++ b/tests/test_rotary_position_embedding.py
@@ -46,8 +46,9 @@ def _generate_sin_and_cos_tables(
 
 
 @skip_if_cuda_not_available
+@pytest.mark.parametrize("device", ("cuda",))
 @pytest.mark.parametrize(
-    "dtype, atol, rtol", ((torch.float32, 0.001, 0), (torch.float16, 0.001, 0.001))
+    "dtype, rtol, atol", ((torch.float32, 0, 0.001), (torch.float16, 0.001, 0.001))
 )
 @pytest.mark.parametrize("inplace", (False, True))
 @pytest.mark.parametrize("interleaved", (False, True))
@@ -55,11 +56,18 @@ def _generate_sin_and_cos_tables(
 @pytest.mark.parametrize("num_heads", (1, 8))
 @pytest.mark.parametrize("seq_len", (1, 128))
 @pytest.mark.parametrize("batch_size", (1, 4))
-def test_cuda(
-    batch_size, seq_len, num_heads, emb_dim, interleaved, inplace, dtype, atol, rtol
+def test_rotary_position_embedding(
+    batch_size,
+    seq_len,
+    num_heads,
+    emb_dim,
+    interleaved,
+    inplace,
+    dtype,
+    device,
+    rtol,
+    atol,
 ):
-    device = "cuda"
-
     input = torch.randn(
         batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device
     )
@@ -78,4 +86,4 @@ def test_cuda(
         input, sin_table, cos_table, interleaved=interleaved
     )
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_rsqrt.py b/tests/test_rsqrt.py
index d87fafb..35252dd 100644
--- a/tests/test_rsqrt.py
+++ b/tests/test_rsqrt.py
@@ -8,14 +8,12 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_rsqrt(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.rsqrt(input)
     reference_output = torch.rsqrt(input)
 
     assert torch.allclose(
-        ninetoothed_output, reference_output, atol=atol, rtol=rtol, equal_nan=True
+        ninetoothed_output, reference_output, rtol=rtol, atol=atol, equal_nan=True
     )
diff --git a/tests/test_scaled_dot_product_attention.py b/tests/test_scaled_dot_product_attention.py
index 433ffa9..b1f1046 100644
--- a/tests/test_scaled_dot_product_attention.py
+++ b/tests/test_scaled_dot_product_attention.py
@@ -22,6 +22,7 @@ def _generate_random_size():
     is_causal_values = (False, True)
     scales = (None, random.uniform(0.05, 0.5))
     dtypes = (torch.float32, torch.float16)
+    devices = ("cuda",)
     with_kv_cache_values = (False, True)
     causal_variants = (None, CausalVariant.LOWER_RIGHT, CausalVariant.UPPER_LEFT)
 
@@ -30,6 +31,7 @@ def _generate_random_size():
         is_causal,
         scale,
         dtype,
+        device,
         with_kv_cache,
         causal_variant,
     ) in itertools.product(
@@ -37,6 +39,7 @@ def _generate_random_size():
         is_causal_values,
         scales,
         dtypes,
+        devices,
         with_kv_cache_values,
         causal_variants,
     ):
@@ -77,20 +80,21 @@ def _generate_random_size():
                 causal_variant,
                 with_kv_cache,
                 dtype,
-                atol,
+                device,
                 rtol,
+                atol,
             )
         )
 
     return (
-        "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, atol, rtol",
+        "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, device, rtol, atol",
         arguments,
     )
 
 
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(
+def test_scaled_dot_product_attention(
     batch_size,
     num_heads_q,
     seq_len_q,
@@ -104,11 +108,10 @@
     causal_variant,
     with_kv_cache,
     dtype,
-    atol,
+    device,
     rtol,
+    atol,
 ):
-    device = "cuda"
-
     shape_q = (batch_size, num_heads_q, seq_len_q, head_dim)
     shape_kv = (batch_size, num_heads_kv, seq_len_kv, head_dim)
 
@@ -180,4 +183,4 @@ def _generate_present_and_slot(tensor):
         enable_gqa=enable_gqa,
     )
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
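In `test_scaled_dot_product_attention.py` (and similarly in `test_mm.py`), the new `device` axis is simply threaded through the existing `itertools.product(...)` call, so every generated tuple now ends with `(..., dtype, device, rtol, atol)` in the same order as the parametrize string. A reduced sketch of that generation style; the function name and the dtype/tolerance values here are illustrative only:

```python
import itertools

import torch


def generate_arguments_sketch():
    dtypes = (torch.float32, torch.float16)
    devices = ("cuda",)

    arguments = []

    for dtype, device in itertools.product(dtypes, devices):
        # Tighter tolerances for float32, looser for float16, mirroring the
        # convention used throughout these tests (values are illustrative).
        rtol, atol = (0.001, 0.001) if dtype is torch.float32 else (0.01, 0.01)

        arguments.append((dtype, device, rtol, atol))

    # The string must list the names in the same order as the tuple elements.
    return "dtype, device, rtol, atol", arguments
```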
diff --git a/tests/test_sigmoid.py b/tests/test_sigmoid.py
index d4fc499..c369fa8 100644
--- a/tests/test_sigmoid.py
+++ b/tests/test_sigmoid.py
@@ -8,16 +8,14 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
+def test_sigmoid(shape, dtype, device, rtol, atol):
     # TODO: Test for `float16` later.
     if dtype is torch.float16:
         return
 
-    device = "cuda"
-
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.sigmoid(input)
     reference_output = torch.sigmoid(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_silu.py b/tests/test_silu.py
index cba8ae2..8aae00b 100644
--- a/tests/test_silu.py
+++ b/tests/test_silu.py
@@ -9,13 +9,11 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_silu(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
 
     # TODO: Add `inplace` tests later.
     ninetoothed_output = ntops.torch.silu(input)
     reference_output = F.silu(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_sin.py b/tests/test_sin.py
index d395b97..5e6cd35 100644
--- a/tests/test_sin.py
+++ b/tests/test_sin.py
@@ -8,16 +8,13 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
+def test_sin(shape, dtype, device, rtol, atol):
     # TODO: Test for `float16` later.
     if dtype is torch.float16:
         return
-
-    device = "cuda"
-
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.sin(input)
     reference_output = torch.sin(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_softmax.py b/tests/test_softmax.py
index 48eaad0..ad86181 100644
--- a/tests/test_softmax.py
+++ b/tests/test_softmax.py
@@ -10,9 +10,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_softmax(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     dim = random.randint(0, input.ndim - 1)
     dtype = random.choice([torch.float16, torch.float32, torch.float64])
@@ -20,4 +18,4 @@
     ninetoothed_output = ntops.torch.softmax(input, dim, dtype)
     reference_output = torch.nn.functional.softmax(input, dim=dim, dtype=dtype)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_sub.py b/tests/test_sub.py
index 726f78f..1909cf0 100644
--- a/tests/test_sub.py
+++ b/tests/test_sub.py
@@ -8,9 +8,7 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
-    device = "cuda"
-
+def test_sub(shape, dtype, device, rtol, atol):
     input = torch.randn(shape, dtype=dtype, device=device)
     other = torch.randn(shape, dtype=dtype, device=device)
     alpha = gauss()
@@ -18,4 +16,4 @@
     ninetoothed_output = ntops.torch.sub(input, other, alpha=alpha)
     reference_output = torch.sub(input, other, alpha=alpha)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_tanh.py b/tests/test_tanh.py
index e2bc49e..7b7d490 100644
--- a/tests/test_tanh.py
+++ b/tests/test_tanh.py
@@ -8,16 +8,14 @@
 @skip_if_cuda_not_available
 @pytest.mark.parametrize(*generate_arguments())
-def test_cuda(shape, dtype, atol, rtol):
+def test_tanh(shape, dtype, device, rtol, atol):
     # TODO: Test for `float16` later.
     if dtype is torch.float16:
         return
 
-    device = "cuda"
-
     input = torch.randn(shape, dtype=dtype, device=device)
 
     ninetoothed_output = ntops.torch.tanh(input)
     reference_output = torch.tanh(input)
 
-    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/utils.py b/tests/utils.py
index 23af61d..ac1949f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,6 +13,8 @@ def generate_arguments(use_float=True):
 
     for ndim in range(1, 5):
         for dtype in dtype_arr:
+            device = "cuda"
+
             if dtype is torch.float32:
                 atol = 0.001
                 rtol = 0.001
@@ -20,9 +22,9 @@
                 atol = 0.01
                 rtol = 0.01
 
-            arguments.append((_random_shape(ndim), dtype, atol, rtol))
+            arguments.append((_random_shape(ndim), dtype, device, rtol, atol))
 
-    return "shape, dtype, atol, rtol", arguments
+    return "shape, dtype, device, rtol, atol", arguments
 
 
 def gauss(mu=0.0, sigma=1.0):
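With `tests/utils.py` now returning `"shape, dtype, device, rtol, atol"`, every test function declares its parameters in that order, and extra axes such as `inplace` or `approximate` are added by stacking further `@pytest.mark.parametrize` decorators. A sketch of a consuming test with a hypothetical identity check as the body; the import path for the shared helpers is assumed here:

```python
import pytest
import torch

from utils import generate_arguments, skip_if_cuda_not_available  # import path assumed


@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_identity(shape, dtype, device, rtol, atol):
    # Parameter names must match the string returned by generate_arguments();
    # `device` replaces the hard-coded device = "cuda" each test used to set.
    input = torch.randn(shape, dtype=dtype, device=device)

    assert torch.allclose(input, input.clone(), rtol=rtol, atol=atol)
```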