Skip to content

Commit c20c8dc

Browse files
committed
add pyfftw_sdp, tests, and relevant fixes
1 parent a2b5d4e commit c20c8dc

4 files changed

Lines changed: 264 additions & 2 deletions

File tree

docstring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_docstring_args(fd, file_name, func_name, class_name=None):
2323
msg += f"function/method: {func_name}\n"
2424
raise RuntimeError(msg)
2525

26-
if class_name is None:
26+
if len(re.findall(r"Returns", docstring)) > 0:
2727
params_section = re.findall(
2828
r"(?<=Parameters)(.*)(?=Returns)", docstring, re.DOTALL
2929
)[0]
@@ -43,7 +43,7 @@ def get_signature_args(fd):
4343
return set([a.arg for a in fd.args.args if a.arg != "self"])
4444

4545

46-
def check_args(doc_args, sig_args, file_name, func_name, class_name=None):
46+
def check_args(docstring_args, signature_args, file_name, func_name, class_name=None):
4747
"""
4848
Compare docstring arguments and signature arguments
4949
"""

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,5 @@ dependencies:
2727
- jupyterlab-myst>=2.0.0
2828
- myst-nb>=1.0.0
2929
- polars>=1.14.0
30+
- fftw>=3.3
31+
- pyfftw>=0.15.0

stumpy/sdp.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66

77
from . import config
88

9+
try:
10+
import pyfftw
11+
12+
FFTW_IS_AVAILABLE = True
13+
except ImportError: # pragma: no cover
14+
FFTW_IS_AVAILABLE = False
15+
916

1017
@njit(fastmath=config.STUMPY_FASTMATH_TRUE)
1118
def _njit_sliding_dot_product(Q, T):
@@ -99,3 +106,180 @@ def _pocketfft_sliding_dot_product(Q, T):
99106
fft_2d = r2c(True, tmp, axis=-1)
100107

101108
return c2r(False, np.multiply(fft_2d[0], fft_2d[1]), n=next_fast_n)[m - 1 : n]
109+
110+
111+
class _PYFFTW_SLIDING_DOT_PRODUCT:
    """
    A class to compute the sliding dot product using FFTW via pyfftw.

    This class uses FFTW (via pyfftw) to efficiently compute the sliding dot product
    between a query sequence Q and a time series T. It preallocates arrays and caches
    FFTW objects to optimize repeated computations with similar-sized inputs.

    Parameters
    ----------
    max_n : int, default=2**20
        Maximum length to preallocate arrays for. This will be the size of the
        real-valued array. A complex-valued array of size `1 + (max_n // 2)`
        will also be preallocated. If inputs exceed this size, arrays will be
        reallocated to accommodate larger sizes.

    Attributes
    ----------
    real_arr : numpy.ndarray
        Preallocated real-valued (float64) array for FFTW computations,
        allocated with `pyfftw.empty_aligned`.

    complex_arr : numpy.ndarray
        Preallocated complex-valued (complex128) array for FFTW computations,
        allocated with `pyfftw.empty_aligned`.

    rfft_objects : dict
        Cache of FFTW forward transform objects, keyed by
        (next_fast_n, n_threads, planning_flag).

    irfft_objects : dict
        Cache of FFTW inverse transform objects, keyed by
        (next_fast_n, n_threads, planning_flag).

    Notes
    -----
    The class maintains internal caches of FFTW objects to avoid redundant planning
    operations when called multiple times with similar-sized inputs and parameters.

    The preallocated arrays are scratch space that is overwritten on every call.
    The array returned by calling the object is therefore a copy and does not
    share memory with these internal buffers.

    Examples
    --------
    >>> sdp_obj = _PYFFTW_SLIDING_DOT_PRODUCT(max_n=1000)
    >>> Q = np.array([1, 2, 3])
    >>> T = np.array([4, 5, 6, 7, 8])
    >>> result = sdp_obj(Q, T)

    References
    ----------
    `FFTW documentation <http://www.fftw.org/>`__

    `pyfftw documentation <https://pyfftw.readthedocs.io/>`__
    """

    def __init__(self, max_n=2**20):
        """
        Initialize the `_PYFFTW_SLIDING_DOT_PRODUCT` object, which can be called
        to compute the sliding dot product using FFTW via pyfftw.

        Parameters
        ----------
        max_n : int, default=2**20
            Maximum length to preallocate arrays for. This will be the size of the
            real-valued array. A complex-valued array of size `1 + (max_n // 2)`
            will also be preallocated.
        """
        # Preallocate arrays
        self.real_arr = pyfftw.empty_aligned(max_n, dtype="float64")
        self.complex_arr = pyfftw.empty_aligned(1 + (max_n // 2), dtype="complex128")

        # Store FFTW objects, keyed by (next_fast_n, n_threads, planning_flag)
        self.rfft_objects = {}
        self.irfft_objects = {}

    def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_ESTIMATE"):
        """
        Compute the sliding dot product between `Q` and `T` using FFTW via pyfftw,
        and cache FFTW objects if not already cached.

        Parameters
        ----------
        Q : numpy.ndarray
            Query array or subsequence.

        T : numpy.ndarray
            Time series or sequence.

        n_threads : int, default=1
            Number of threads to use for FFTW computations.

        planning_flag : str, default="FFTW_ESTIMATE"
            The planning flag that will be used in FFTW for planning.
            See pyfftw documentation for details. Current options, ordered
            ascendingly by the level of aggressiveness in planning, are:
            "FFTW_ESTIMATE", "FFTW_MEASURE", "FFTW_PATIENT", and "FFTW_EXHAUSTIVE".
            The more aggressive the planning, the longer the planning time, but
            the faster the execution time.

        Returns
        -------
        out : numpy.ndarray
            Sliding dot product between `Q` and `T`. This is a newly allocated
            array (a copy of the relevant slice of the internal buffer), so it
            remains valid after subsequent calls.

        Notes
        -----
        The planning_flag is defaulted to "FFTW_ESTIMATE" to be aligned with
        MATLAB's FFTW usage (as of version R2025b)
        See: https://www.mathworks.com/help/matlab/ref/fftw.html

        This implementation is inspired by the answer on StackOverflow:
        https://stackoverflow.com/a/30615425/2955541
        """
        m = Q.shape[0]
        n = T.shape[0]
        next_fast_n = pyfftw.next_fast_len(n)

        # Update preallocated arrays if needed
        if next_fast_n > len(self.real_arr):
            self.real_arr = pyfftw.empty_aligned(next_fast_n, dtype="float64")
            self.complex_arr = pyfftw.empty_aligned(
                1 + (next_fast_n // 2), dtype="complex128"
            )

        # Views into the shared buffers, trimmed to the transform length
        real_arr = self.real_arr[:next_fast_n]
        complex_arr = self.complex_arr[: 1 + (next_fast_n // 2)]

        # Get or create FFTW objects
        key = (next_fast_n, n_threads, planning_flag)

        rfft_obj = self.rfft_objects.get(key, None)
        if rfft_obj is None:
            rfft_obj = pyfftw.FFTW(
                input_array=real_arr,
                output_array=complex_arr,
                direction="FFTW_FORWARD",
                flags=(planning_flag,),
                threads=n_threads,
            )
            self.rfft_objects[key] = rfft_obj
        else:
            # The buffers may have been reallocated since this plan was built,
            # so re-point the cached plan at the current arrays
            rfft_obj.update_arrays(real_arr, complex_arr)

        irfft_obj = self.irfft_objects.get(key, None)
        if irfft_obj is None:
            irfft_obj = pyfftw.FFTW(
                input_array=complex_arr,
                output_array=real_arr,
                direction="FFTW_BACKWARD",
                flags=(planning_flag, "FFTW_DESTROY_INPUT"),
                threads=n_threads,
            )
            self.irfft_objects[key] = irfft_obj
        else:
            irfft_obj.update_arrays(complex_arr, real_arr)

        # RFFT(T)
        real_arr[:n] = T
        real_arr[n:] = 0.0
        rfft_obj.execute()  # output is in complex_arr
        # Keep RFFT(T) aside because complex_arr is overwritten by RFFT(Q) below
        complex_arr_T = complex_arr.copy()

        # RFFT(Q)
        # Scale by 1/next_fast_n to account for
        # FFTW's unnormalized inverse FFT via execute()
        np.multiply(Q[::-1], 1.0 / next_fast_n, out=real_arr[:m])
        real_arr[m:] = 0.0
        rfft_obj.execute()  # output is in complex_arr

        # RFFT(T) * RFFT(Q)
        np.multiply(complex_arr, complex_arr_T, out=complex_arr)

        # IRFFT (input is in complex_arr)
        irfft_obj.execute()  # output is in real_arr

        # Return a copy: `real_arr` is a view of the reusable internal buffer,
        # which the next call would otherwise overwrite underneath the caller
        return real_arr[m - 1 : n].copy()
283+
284+
285+
# Instantiating the class calls into `pyfftw`, so only do it when the optional
# import succeeded. Otherwise importing this module would raise a `NameError`
# on machines without pyfftw installed.
if FFTW_IS_AVAILABLE:
    _pyfftw_sliding_dot_product = _PYFFTW_SLIDING_DOT_PRODUCT()
else:  # pragma: no cover
    _pyfftw_sliding_dot_product = None

tests/test_sdp.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,79 @@ def test_sdp_power2():
172172
raise e
173173

174174
return
175+
176+
177+
def test_pyfftw_sdp_max_n():
    """
    Check that the pyfftw-based sliding dot product still matches the naive
    reference when `len(T)` exceeds `max_n`, i.e. when the internal
    preallocated arrays have to be resized.
    """
    pyfftw_sdp = sdp._PYFFTW_SLIDING_DOT_PRODUCT(max_n=2**10)

    # 2**11 > max_n (2**10), which forces the buffers to be reallocated
    T = np.random.rand(2**11)
    Q = np.random.rand(2**8)

    comp = pyfftw_sdp(Q, T)
    ref = naive_sliding_dot_product(Q, T)

    np.testing.assert_allclose(actual=comp, desired=ref)
193+
194+
195+
def test_pyfftw_sdp_cache():
    """
    Check that calling the pyfftw-based sliding dot product populates both
    FFTW object caches under the expected key.
    """
    pyfftw_sdp = sdp._PYFFTW_SLIDING_DOT_PRODUCT(max_n=2**10)

    # A freshly constructed object starts with empty caches
    assert pyfftw_sdp.rfft_objects == {}
    assert pyfftw_sdp.irfft_objects == {}

    T = np.random.rand(2**5)
    Q = np.random.rand(2**2)

    call_kwargs = {"n_threads": 1, "planning_flag": "FFTW_ESTIMATE"}
    pyfftw_sdp(Q, T, **call_kwargs)

    # Both the forward and the inverse plan must now be cached
    expected_key = (len(T), call_kwargs["n_threads"], call_kwargs["planning_flag"])
    assert expected_key in pyfftw_sdp.rfft_objects
    assert expected_key in pyfftw_sdp.irfft_objects
215+
216+
217+
def test_pyfftw_sdp_update_arrays():
    """
    Check that FFTW objects cached for a small input are reused, and still
    produce correct results, after a larger input has forced the internal
    preallocated arrays to be resized.
    """
    pyfftw_sdp = sdp._PYFFTW_SLIDING_DOT_PRODUCT(max_n=2**10)
    call_kwargs = {"n_threads": 1, "planning_flag": "FFTW_ESTIMATE"}

    # First call: small input, creates and caches the FFTW objects
    T1 = np.random.rand(2**5)
    Q1 = np.random.rand(2**2)
    pyfftw_sdp(Q1, T1, **call_kwargs)

    # Second call: len(T2) > max_n, which triggers array resizing
    T2 = np.random.rand(2**11)
    Q2 = np.random.rand(2**3)
    pyfftw_sdp(Q2, T2, **call_kwargs)

    # Remember the objects cached for (Q1, T1) before calling with them again
    key1 = (len(T1), call_kwargs["n_threads"], call_kwargs["planning_flag"])
    cached_rfft = pyfftw_sdp.rfft_objects[key1]
    cached_irfft = pyfftw_sdp.irfft_objects[key1]

    comp = pyfftw_sdp(Q1, T1, **call_kwargs)
    ref = naive_sliding_dot_product(Q1, T1)

    # test for correctness
    np.testing.assert_allclose(actual=comp, desired=ref)

    # The cache entries must be the very same FFTW objects as before
    assert pyfftw_sdp.rfft_objects[key1] is cached_rfft
    assert pyfftw_sdp.irfft_objects[key1] is cached_irfft

0 commit comments

Comments
 (0)