Skip to content

Commit 9869285

Browse files
committed
refactored test functions
1 parent 7c1e52c commit 9869285

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

tests/test_sdp.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import inspect
2+
import warnings
3+
14
import naive
25
import numpy as np
36
import pytest
@@ -15,36 +18,24 @@
1518
]
1619

1720

18-
@pytest.mark.parametrize("Q, T", test_data)
19-
def test_njit_sliding_dot_product(Q, T):
20-
ref_mp = naive.rolling_window_dot_product(Q, T)
21-
comp_mp = sdp._njit_sliding_dot_product(Q, T)
22-
npt.assert_almost_equal(ref_mp, comp_mp)
23-
24-
25-
@pytest.mark.parametrize("Q, T", test_data)
26-
def test_convolve_sliding_dot_product(Q, T):
27-
ref_mp = naive.rolling_window_dot_product(Q, T)
28-
comp_mp = sdp._convolve_sliding_dot_product(Q, T)
29-
npt.assert_almost_equal(ref_mp, comp_mp)
30-
31-
32-
@pytest.mark.parametrize("Q, T", test_data)
33-
def test_oaconvolve_sliding_dot_product(Q, T):
34-
ref_mp = naive.rolling_window_dot_product(Q, T)
35-
comp_mp = sdp._oaconvolve_sliding_dot_product(Q, T)
36-
npt.assert_almost_equal(ref_mp, comp_mp)
21+
def get_sdp_function_names():
22+
out = []
23+
for func_name, func in inspect.getmembers(sdp, inspect.isfunction):
24+
if func_name.endswith("sliding_dot_product"):
25+
out.append(func_name)
3726

38-
39-
@pytest.mark.parametrize("Q, T", test_data)
40-
def test_pocketfft_sliding_dot_product(Q, T):
41-
ref_mp = naive.rolling_window_dot_product(Q, T)
42-
comp_mp = sdp._pocketfft_sliding_dot_product(Q, T)
43-
npt.assert_almost_equal(ref_mp, comp_mp)
27+
return out
4428

4529

4630
@pytest.mark.parametrize("Q, T", test_data)
4731
def test_sliding_dot_product(Q, T):
48-
ref_mp = naive.rolling_window_dot_product(Q, T)
49-
comp_mp = sdp._sliding_dot_product(Q, T)
50-
npt.assert_almost_equal(ref_mp, comp_mp)
32+
for func_name in get_sdp_function_names():
33+
func = getattr(sdp, func_name)
34+
try:
35+
comp = func(Q, T)
36+
ref = naive.rolling_window_dot_product(Q, T)
37+
npt.assert_allclose(comp, ref)
38+
except Exception as e: # pragma: no cover
39+
msg = f"Error in {func_name}, with n_Q={len(Q)} and n_T={len(T)}"
40+
warnings.warn(msg)
41+
raise e

0 commit comments

Comments
 (0)