From 3a5a680ced1a22c0478c2c2fb1bd944f517206ae Mon Sep 17 00:00:00 2001 From: Ivana Date: Sat, 30 May 2026 09:38:21 +0000 Subject: [PATCH] fix(sum): use pairwise summation for floating-point linalg::Sum linalg::Sum accumulated every element into one scalar in a serial loop, whose worst-case rounding error grows as O(N * eps). The Accuracy test in tests/linalg_test/sum_test.cpp (10000 identical floating-point values whose bit patterns are chosen to expose round-off) failed under naive accumulation and was disabled. Add a pairwise-summation primitive (pairwise_sum.hpp): PairwiseSum mirrors NumPy's np.add.reduce -- a straight loop below 8 elements, an eight-accumulator unrolled block up to 128, and a halve-and-recurse split rounded to a multiple of eight above that. Worst-case error growth drops to O(log N * eps) at essentially the same cost. PairwiseSum consumes any random-access range, so it also serves strided reductions later. Rewrite the Float/Double/ComplexFloat/ComplexDouble branches of Sum_internal to call PairwiseSum over the contiguous storage span and re-enable the accuracy test. Integer branches keep the straight loop since they do not round. The GPU path (cuReduce_gpu) already performs a block/tree reduction with bounded error, so it is left unchanged. For the disabled test's values the new code lands within 2 ULP of the reference, against the 4-ULP tolerance of EXPECT_NUMBER_EQ; naive accumulation was 108 ULP off. Co-authored-by: Claude --- .../linalg_internal_cpu/Sum_internal.cpp | 23 +++---- .../linalg_internal_cpu/pairwise_sum.hpp | 62 +++++++++++++++++++ tests/linalg_test/sum_test.cpp | 2 +- 3 files changed, 70 insertions(+), 17 deletions(-) create mode 100644 src/backend/linalg_internal_cpu/pairwise_sum.hpp diff --git a/src/backend/linalg_internal_cpu/Sum_internal.cpp b/src/backend/linalg_internal_cpu/Sum_internal.cpp index df871390d..7fef6385c 100644 --- a/src/backend/linalg_internal_cpu/Sum_internal.cpp +++ b/src/backend/linalg_internal_cpu/Sum_internal.cpp @@ -1,8 +1,11 @@ #include "Sum_internal.hpp" +#include + #include "boost/smart_ptr/intrusive_ptr.hpp" #include "backend/Storage.hpp" +#include "backend/linalg_internal_cpu/pairwise_sum.hpp" #include "cytnx_error.hpp" #include "Type.hpp" @@ -88,10 +91,7 @@ namespace cytnx { cytnx_double *_ten = (cytnx_double *)ten->data(); cytnx_double *_out = (cytnx_double *)out->data(); - _out[0] = 0; - for (cytnx_uint64 n = 0; n < Nelem; n++) { - _out[0] += _ten[n]; - } + _out[0] = PairwiseSum(std::span(_ten, Nelem)); } void Sum_internal_f(boost::intrusive_ptr &out, @@ -99,20 +99,14 @@ namespace cytnx { cytnx_float *_ten = (cytnx_float *)ten->data(); cytnx_float *_out = (cytnx_float *)out->data(); - _out[0] = 0; - for (cytnx_uint64 n = 0; n < Nelem; n++) { - _out[0] += _ten[n]; - } + _out[0] = PairwiseSum(std::span(_ten, Nelem)); } void Sum_internal_cd(boost::intrusive_ptr &out, const boost::intrusive_ptr &ten, const cytnx_uint64 &Nelem) { cytnx_complex128 *_ten = (cytnx_complex128 *)ten->data(); cytnx_complex128 *_out = (cytnx_complex128 *)out->data(); - _out[0] = 0; - for (cytnx_uint64 n = 0; n < Nelem; n++) { - _out[0] += _ten[n]; - } + _out[0] = PairwiseSum(std::span(_ten, Nelem)); } void Sum_internal_cf(boost::intrusive_ptr &out, @@ -120,10 +114,7 @@ namespace cytnx { cytnx_complex64 *_ten = (cytnx_complex64 *)ten->data(); cytnx_complex64 *_out = (cytnx_complex64 *)out->data(); - _out[0] = 0; - for (cytnx_uint64 n = 0; n < Nelem; n++) { - _out[0] += _ten[n]; - } + _out[0] = PairwiseSum(std::span(_ten, Nelem)); } void Sum_internal_b(boost::intrusive_ptr &out, diff --git a/src/backend/linalg_internal_cpu/pairwise_sum.hpp b/src/backend/linalg_internal_cpu/pairwise_sum.hpp new file mode 100644 index 000000000..2680307ff --- /dev/null +++ b/src/backend/linalg_internal_cpu/pairwise_sum.hpp @@ -0,0 +1,62 @@ +#ifndef CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_ +#define CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_ + +#include +#include +#include + +namespace cytnx { + namespace linalg_internal { + + // Recursive (divide-and-conquer) core of the pairwise summation, matching + // NumPy's np.add.reduce: a straight loop for the smallest blocks, an + // eight-accumulator unrolled loop up to 128 elements, and a split into two + // halves (rounded to a multiple of eight) above that. Worst-case rounding + // error grows as O(log N * eps) instead of the O(N * eps) of a naive serial + // accumulation, at essentially the same cost. + template + T PairwiseSumBlocks(It first, std::size_t n) { + if (n < 8) { + T res = T(0); + for (std::size_t i = 0; i < n; ++i) res += first[static_cast(i)]; + return res; + } + if (n <= 128) { + T r0 = first[0], r1 = first[1], r2 = first[2], r3 = first[3]; + T r4 = first[4], r5 = first[5], r6 = first[6], r7 = first[7]; + std::size_t i = 8; + for (; i + 8 <= n; i += 8) { + auto p = first + static_cast(i); + r0 += p[0]; + r1 += p[1]; + r2 += p[2]; + r3 += p[3]; + r4 += p[4]; + r5 += p[5]; + r6 += p[6]; + r7 += p[7]; + } + T res = ((r0 + r1) + (r2 + r3)) + ((r4 + r5) + (r6 + r7)); + for (; i < n; ++i) res += first[static_cast(i)]; + return res; + } + std::size_t half = n / 2; + half -= half % 8; + return PairwiseSumBlocks(first, half) + + PairwiseSumBlocks(first + static_cast(half), n - half); + } + + // Pairwise sum over a random-access range. The element type is deduced from + // the range. A contiguous std::span sums every element; pass a strided view + // (see stride_view.hpp) to sum a strided sequence such as a matrix diagonal. + template + std::ranges::range_value_t PairwiseSum(R&& range) { + using T = std::ranges::range_value_t; + return PairwiseSumBlocks(std::ranges::begin(range), + static_cast(std::ranges::size(range))); + } + + } // namespace linalg_internal +} // namespace cytnx + +#endif // CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_ diff --git a/tests/linalg_test/sum_test.cpp b/tests/linalg_test/sum_test.cpp index d130ec71c..23b7a4aed 100644 --- a/tests/linalg_test/sum_test.cpp +++ b/tests/linalg_test/sum_test.cpp @@ -49,7 +49,7 @@ namespace cytnx { * Note: `cytnx_bool` is not supported for the `linalg::Sum()` function. * This test also assesses the accuracy of summing floating-point numbers. */ - TYPED_TEST(LinalgSumHomogeneousValuesTest, DISABLED_Accuracy) { + TYPED_TEST(LinalgSumHomogeneousValuesTest, Accuracy) { TypeParam value = LinalgSumHomogeneousValuesTest::value; int element_number = 10000; unsigned int dtype = Type_class().cy_typeid(value);