Skip to content

Commit 0a3427e

Browse files
committed
improved readability
1 parent 6051ab4 commit 0a3427e

1 file changed

Lines changed: 18 additions & 13 deletions

File tree

stumpy/sdp.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -267,29 +267,34 @@ def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_ESTIMATE"):
267267
self.rfft_objects[key] = rfft_obj
268268
self.irfft_objects[key] = irfft_obj
269269
else:
270+
# Update the input and output arrays of the cached FFTW objects
271+
# in case their original input and output arrays were reallocated
272+
# in a previous call
270273
rfft_obj.update_arrays(real_arr, complex_arr)
271274
irfft_obj.update_arrays(complex_arr, real_arr)
272275

273276
# Compute RFFT of T
274277
real_arr[:n] = T
275278
real_arr[n:] = 0.0
276-
rfft_obj.execute() # output is stored in complex_arr
277-
278-
# need to make a copy since the array will be
279-
# overwritten later during the RFFT(Q) step
280-
rfft_of_T = complex_arr.copy()
281-
282-
# Compute RFFT of Q (reversed and scaled by 1/next_fast_n)
279+
rfft_obj.execute()
280+
rfft_of_T = rfft_obj.output_array.copy()
281+
# output array is stored in complex_arr, so make a copy
282+
# to avoid losing it when it is overwritten when computing
283+
# the RFFT of Q
284+
285+
# Compute RFFT of Q (reversed, scaled, and zero-padded)
286+
# Note: scaling is needed since the thin wrapper `execute`
287+
# is called which does not perform normalization
283288
np.multiply(Q[::-1], 1.0 / next_fast_n, out=real_arr[:m])
284289
real_arr[m:] = 0.0
285-
rfft_obj.execute() # output is stored in complex_arr
286-
rfft_of_Q = complex_arr
290+
rfft_obj.execute()
291+
rfft_of_Q = rfft_obj.output_array
287292

288-
# Compute IRFFT of the element-wise product of the RFFTs
289-
np.multiply(rfft_of_Q, rfft_of_T, out=complex_arr)
290-
irfft_obj.execute() # output is stored in real_arr
293+
# Convert back to time domain by taking the inverse RFFT
294+
np.multiply(rfft_of_T, rfft_of_Q, out=irfft_obj.input_array)
295+
irfft_obj.execute()
291296

292-
return real_arr[m - 1 : n]
297+
return irfft_obj.output_array[m - 1 : n]
293298

294299

295300
if PYFFTW_IS_AVAILABLE: # pragma: no cover

0 commit comments

Comments
 (0)