-
Notifications
You must be signed in to change notification settings - Fork 22
fix(sum): use pairwise summation for floating-point linalg::Sum #849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| #ifndef CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_ | ||
| #define CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_ | ||
|
|
||
| #include <cstddef> | ||
| #include <iterator> | ||
| #include <ranges> | ||
|
|
||
| namespace cytnx { | ||
| namespace linalg_internal { | ||
|
|
||
| // Recursive (divide-and-conquer) core of the pairwise summation, matching | ||
| // NumPy's np.add.reduce: a straight loop for the smallest blocks, an | ||
| // eight-accumulator unrolled loop up to 128 elements, and a split into two | ||
| // halves (rounded to a multiple of eight) above that. Worst-case rounding | ||
| // error grows as O(log N * eps) instead of the O(N * eps) of a naive serial | ||
| // accumulation, at essentially the same cost. | ||
| template <class T, std::random_access_iterator It> | ||
| T PairwiseSumBlocks(It first, std::size_t n) { | ||
| if (n < 8) { | ||
| T res = T(0); | ||
| for (std::size_t i = 0; i < n; ++i) res += first[static_cast<std::ptrdiff_t>(i)]; | ||
| return res; | ||
| } | ||
| if (n <= 128) { | ||
| T r0 = first[0], r1 = first[1], r2 = first[2], r3 = first[3]; | ||
| T r4 = first[4], r5 = first[5], r6 = first[6], r7 = first[7]; | ||
| std::size_t i = 8; | ||
| for (; i + 8 <= n; i += 8) { | ||
| auto p = first + static_cast<std::ptrdiff_t>(i); | ||
| r0 += p[0]; | ||
| r1 += p[1]; | ||
| r2 += p[2]; | ||
| r3 += p[3]; | ||
| r4 += p[4]; | ||
| r5 += p[5]; | ||
| r6 += p[6]; | ||
| r7 += p[7]; | ||
| } | ||
| T res = ((r0 + r1) + (r2 + r3)) + ((r4 + r5) + (r6 + r7)); | ||
| for (; i < n; ++i) res += first[static_cast<std::ptrdiff_t>(i)]; | ||
| return res; | ||
| } | ||
| std::size_t half = n / 2; | ||
| half -= half % 8; | ||
| return PairwiseSumBlocks<T>(first, half) + | ||
| PairwiseSumBlocks<T>(first + static_cast<std::ptrdiff_t>(half), n - half); | ||
| } | ||
|
|
||
| // Pairwise sum over a random-access range. The element type is deduced from | ||
| // the range. A contiguous std::span sums every element; pass a strided view | ||
| // (see stride_view.hpp) to sum a strided sequence such as a matrix diagonal. | ||
| template <std::ranges::random_access_range R> | ||
| std::ranges::range_value_t<R> PairwiseSum(R&& range) { | ||
| using T = std::ranges::range_value_t<R>; | ||
| return PairwiseSumBlocks<T>(std::ranges::begin(range), | ||
| static_cast<std::size_t>(std::ranges::size(range))); | ||
| } | ||
|
|
||
| } // namespace linalg_internal | ||
| } // namespace cytnx | ||
|
|
||
| #endif // CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,7 +49,7 @@ namespace cytnx { | |
| * Note: `cytnx_bool` is not supported for the `linalg::Sum()` function. | ||
| * This test also assesses the accuracy of summing floating-point numbers. | ||
| */ | ||
| TYPED_TEST(LinalgSumHomogeneousValuesTest, DISABLED_Accuracy) { | ||
| TYPED_TEST(LinalgSumHomogeneousValuesTest, Accuracy) { | ||
| TypeParam value = LinalgSumHomogeneousValuesTest<TypeParam>::value; | ||
| int element_number = 10000; | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Review written by Claude on behalf of @IvanaGyro.) This single Suggest adding sizes that straddle the thresholds, e.g. |
||
| unsigned int dtype = Type_class().cy_typeid(value); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Review written by Claude on behalf of @IvanaGyro.)
All the FP accuracy assertions here use homogeneous (identical) values, the one input distribution where even naive serial summation stays nearly exact. That doesn't demonstrate or guard the actual benefit of pairwise summation — improved accuracy under a large dynamic range of magnitudes.
Consider one heterogeneous-magnitude test for
cytnx_double/cytnx_floatwhere naive accumulation visibly loses precision but pairwise recovers it, e.g. a tensor like[1e16, 1, 1, ..., 1, -1e16](the1s vanish in a naive running sum once it reaches1e16, but the balanced tree keeps them). That both documents intent and locks in the O(log N) error behavior.