Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions src/backend/linalg_internal_cpu/Sum_internal.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "Sum_internal.hpp"

#include <span>

#include "boost/smart_ptr/intrusive_ptr.hpp"

#include "backend/Storage.hpp"
#include "backend/linalg_internal_cpu/pairwise_sum.hpp"
#include "cytnx_error.hpp"
#include "Type.hpp"

Expand Down Expand Up @@ -88,42 +91,30 @@ namespace cytnx {
cytnx_double *_ten = (cytnx_double *)ten->data();
cytnx_double *_out = (cytnx_double *)out->data();

_out[0] = 0;
for (cytnx_uint64 n = 0; n < Nelem; n++) {
_out[0] += _ten[n];
}
_out[0] = PairwiseSum(std::span<const cytnx_double>(_ten, Nelem));
}

// Sums the first `Nelem` cytnx_float values read from `ten`'s data buffer and
// writes the result to element 0 of `out`'s data buffer.
//
// The reduction is delegated to PairwiseSum over a contiguous std::span, so
// no separate zero-initialisation or serial accumulation pass is performed:
// any such pass would be overwritten by the assignment below.
void Sum_internal_f(boost::intrusive_ptr<Storage_base> &out,
                    const boost::intrusive_ptr<Storage_base> &ten, const cytnx_uint64 &Nelem) {
  cytnx_float *_ten = (cytnx_float *)ten->data();
  cytnx_float *_out = (cytnx_float *)out->data();

  _out[0] = PairwiseSum(std::span<const cytnx_float>(_ten, Nelem));
}
// Sums the first `Nelem` cytnx_complex128 values read from `ten`'s data
// buffer and writes the result to element 0 of `out`'s data buffer.
//
// The reduction is delegated to PairwiseSum over a contiguous std::span, so
// no separate zero-initialisation or serial accumulation pass is performed:
// any such pass would be overwritten by the assignment below.
void Sum_internal_cd(boost::intrusive_ptr<Storage_base> &out,
                     const boost::intrusive_ptr<Storage_base> &ten, const cytnx_uint64 &Nelem) {
  cytnx_complex128 *_ten = (cytnx_complex128 *)ten->data();
  cytnx_complex128 *_out = (cytnx_complex128 *)out->data();

  _out[0] = PairwiseSum(std::span<const cytnx_complex128>(_ten, Nelem));
}

// Sums the first `Nelem` cytnx_complex64 values read from `ten`'s data buffer
// and writes the result to element 0 of `out`'s data buffer.
//
// The reduction is delegated to PairwiseSum over a contiguous std::span, so
// no separate zero-initialisation or serial accumulation pass is performed:
// any such pass would be overwritten by the assignment below.
void Sum_internal_cf(boost::intrusive_ptr<Storage_base> &out,
                     const boost::intrusive_ptr<Storage_base> &ten, const cytnx_uint64 &Nelem) {
  cytnx_complex64 *_ten = (cytnx_complex64 *)ten->data();
  cytnx_complex64 *_out = (cytnx_complex64 *)out->data();

  _out[0] = PairwiseSum(std::span<const cytnx_complex64>(_ten, Nelem));
}

void Sum_internal_b(boost::intrusive_ptr<Storage_base> &out,
Expand Down
62 changes: 62 additions & 0 deletions src/backend/linalg_internal_cpu/pairwise_sum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#ifndef CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_
#define CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_

#include <cstddef>
#include <iterator>
#include <ranges>

namespace cytnx {
namespace linalg_internal {

// Divide-and-conquer kernel behind the pairwise summation.
//
// Sums the `n` elements starting at `first`, accumulating in type `T`. Three
// regimes are used, chosen by `n`:
//   * n < 8         : plain left-to-right accumulation starting from T(0);
//   * 8 <= n <= 128 : eight independent running sums over consecutive groups
//                     of eight elements, combined as a balanced tree, followed
//                     by a serial pass over the (at most seven) leftover
//                     elements;
//   * n > 128       : the input is cut at n/2 rounded down to a multiple of
//                     eight, each part is summed recursively, and the two
//                     partial sums are added.
//
// The layout follows the scheme described for NumPy's np.add.reduce. Compared
// with a single serial accumulation, whose worst-case rounding error grows as
// O(N * eps), the balanced splitting keeps it at O(log N * eps) for
// essentially the same amount of work.
template <class T, std::random_access_iterator It>
T PairwiseSumBlocks(It first, std::size_t n) {
  constexpr std::size_t kLanes = 8;         // number of independent accumulators
  constexpr std::size_t kBlockLimit = 128;  // largest size handled without splitting

  // Large input: split near the middle, keeping the left part a multiple of
  // kLanes so that the unrolled loop below has no tail on that side.
  if (n > kBlockLimit) {
    const std::size_t left_count = (n / 2) / kLanes * kLanes;
    return PairwiseSumBlocks<T>(first, left_count) +
           PairwiseSumBlocks<T>(first + static_cast<std::ptrdiff_t>(left_count), n - left_count);
  }

  // Tiny input: straightforward serial accumulation.
  if (n < kLanes) {
    T total = T(0);
    It cursor = first;
    for (std::size_t left = n; left > 0; --left, ++cursor) {
      total += *cursor;
    }
    return total;
  }

  // Medium input: seed one accumulator per lane from the first eight elements.
  T a0 = first[0];
  T a1 = first[1];
  T a2 = first[2];
  T a3 = first[3];
  T a4 = first[4];
  T a5 = first[5];
  T a6 = first[6];
  T a7 = first[7];

  It cursor = first + static_cast<std::ptrdiff_t>(kLanes);
  std::size_t remaining = n - kLanes;

  // Consume whole groups of eight, one element per lane.
  while (remaining >= kLanes) {
    a0 += cursor[0];
    a1 += cursor[1];
    a2 += cursor[2];
    a3 += cursor[3];
    a4 += cursor[4];
    a5 += cursor[5];
    a6 += cursor[6];
    a7 += cursor[7];
    cursor += static_cast<std::ptrdiff_t>(kLanes);
    remaining -= kLanes;
  }

  // Merge the lanes as a balanced tree, then fold in the leftover elements.
  T total = ((a0 + a1) + (a2 + a3)) + ((a4 + a5) + (a6 + a7));
  while (remaining > 0) {
    total += *cursor;
    ++cursor;
    --remaining;
  }
  return total;
}

// Pairwise sum over a random-access range. The element type is deduced from
// the range. A contiguous std::span sums every element; pass a strided view
// (see stride_view.hpp) to sum a strided sequence such as a matrix diagonal.
template <std::ranges::random_access_range R>
std::ranges::range_value_t<R> PairwiseSum(R&& range) {
using T = std::ranges::range_value_t<R>;
return PairwiseSumBlocks<T>(std::ranges::begin(range),
static_cast<std::size_t>(std::ranges::size(range)));
}

} // namespace linalg_internal
} // namespace cytnx

#endif // CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_
2 changes: 1 addition & 1 deletion tests/linalg_test/sum_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace cytnx {
* Note: `cytnx_bool` is not supported for the `linalg::Sum()` function.
* This test also assesses the accuracy of summing floating-point numbers.
*/
TYPED_TEST(LinalgSumHomogeneousValuesTest, DISABLED_Accuracy) {
TYPED_TEST(LinalgSumHomogeneousValuesTest, Accuracy) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Review written by Claude on behalf of @IvanaGyro.)

All the FP accuracy assertions here use homogeneous (identical) values, the one input distribution where even naive serial summation stays nearly exact. That doesn't demonstrate or guard the actual benefit of pairwise summation — improved accuracy under a large dynamic range of magnitudes.

Consider one heterogeneous-magnitude test for cytnx_double/cytnx_float where naive accumulation visibly loses precision but pairwise recovers it, e.g. a tensor like [1e16, 1, 1, ..., 1, -1e16] (the 1s vanish in a naive running sum once it reaches 1e16, but the balanced tree keeps them). That both documents intent and locks in the O(log N) error behavior.

TypeParam value = LinalgSumHomogeneousValuesTest<TypeParam>::value;
int element_number = 10000;
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Review written by Claude on behalf of @IvanaGyro.)

This single element_number = 10000 case only exercises the recursive-split branch (plus the 8-accumulator base). The two boundary branches in PairwiseSumBlocks go untested: the serial n < 8 path and the 8 <= n <= 128 unrolled path, including the off-by-one tails (n not a multiple of 8). A regression that broke either branch would still pass CI.

Suggest adding sizes that straddle the thresholds, e.g. 7, 8, 9, 128, 129, in addition to 10000 — a small value-parameterized companion would cover all three code paths cheaply.

unsigned int dtype = Type_class().cy_typeid(value);
Expand Down
Loading