refactor(trace): implement the GPU Tensor trace as CUDA kernels#850

Draft
IvanaGyro wants to merge 1 commit into claude/trace-cpu from claude/gallant-franklin-4g4rY

Conversation

@IvanaGyro
Collaborator

@IvanaGyro IvanaGyro commented May 29, 2026

Fixes #834.

Stacked on #849 → stride_view → trace-sig-refactor → trace-cpu. This PR's diff is only the GPU CUDA implementation. Each prerequisite PR retargets up the chain as it merges.

Summary

Implements the GPU Tensor trace as native CUDA kernels reading storage directly via Tensor::strides(), replacing the previous Contract-based path that built an identity UniTensor per call.

Change

  • Trace2dKernel<T>: single-block shared-memory tree reduction for the rank-2 path ({n, n} → scalar). The block size is a plain constexpr int kTraceThreadsPerBlock = 512; a static_assert pins the pow-of-2 invariant the tree reduction relies on.
  • TraceNdKernel<T>: one thread per output element for the rank-≥3 path; each thread sums the Ndiag diagonal entries that share its remaining-rank multi-index. Storage offsets for each output element are computed once on the host and shipped to the device via a small cudaMalloc / cudaMemcpy / cudaFree triple (cudaFree synchronizes, so no separate cudaDeviceSynchronize; the 2D kernel runs fully asynchronously).
  • TraceImplGpu<T> is a single templated host helper in an unnamed namespace that derives the same Ndiag / Nelem / accu / remain_rank_id / is_2d params inline from Tn.shape() + Tn.strides() (same shape as the CPU TraceImpl<T>).
  • Type-generic over T using T(0) / +=; complex storage is reinterpret_cast to cuda::std::complex<{float,double}>, which provides device operator+ / zero-construction (CUDA ≥ 12.6).
  • Ndiag == 0 / Nelem == 0 guards skip both the host-side offset materialization and the kernel launch (an <<<0, ...>>> launch would return cudaErrorInvalidConfiguration).
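
To make the power-of-two requirement in the first bullet concrete, here is a minimal host-side emulation of a single-block tree reduction over a strided diagonal. It is a sketch, not the kernel from this PR: `kThreads`, `TreeReduceDiagonal`, and the two-phase layout are hypothetical names and structure chosen for illustration, and the inner loops stand in for what threads would do in parallel (with a barrier between halving steps on the device).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for kTraceThreadsPerBlock.
constexpr std::size_t kThreads = 512;
static_assert(kThreads > 0 && (kThreads & (kThreads - 1)) == 0,
              "the halving loop below only visits every slot for a power of two");

// Host emulation of a single-block tree reduction over a strided diagonal.
// `in` is flat storage, `n_diag` the diagonal length, and `diag_stride` the
// storage distance between consecutive diagonal entries.
template <class T>
T TreeReduceDiagonal(const std::vector<T>& in, std::size_t n_diag,
                     std::size_t diag_stride) {
  // Phase 1: "thread" t accumulates entries t, t + kThreads, t + 2*kThreads, ...
  std::vector<T> shared(kThreads, T(0));
  for (std::size_t t = 0; t < kThreads; ++t) {
    for (std::size_t d = t; d < n_diag; d += kThreads) {
      shared[t] += in[d * diag_stride];
    }
  }
  // Phase 2: halve the active range until a single partial sum remains.
  for (std::size_t half = kThreads / 2; half > 0; half /= 2) {
    for (std::size_t t = 0; t < half; ++t) shared[t] += shared[t + half];
  }
  return shared[0];
}
```

With a non-power-of-two size such as 6, the halving sequence 3, 1 would never fold slot 2 into slot 0, which is the failure mode the static_assert rules out.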

Test plan

  • CPU: openblas-cpu full suite 963/963 passed (the GPU TU is excluded; no CPU regressions from the GPU patch).
  • CPU: mkl-cpu full suite 963/963 passed.
  • GPU CI — the .cu changes can't be compiled locally (no CUDA toolkit here); relying on CI to exercise tests/gpu/.

The CPU+benchmark coverage lands in the prerequisite PR; this PR's correctness is exercised on GPU CI by the same LinalgTraceTest cases (rank-2 path, permuted ranks 2–5, complex+integer dtypes, output rank invariants) the CPU PR introduces.

Draft — opened for review.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces benchmarks for linalg::Trace and refactors both the CPU and GPU implementations of Trace_internal to perform in-place diagonal summation using physical strides, eliminating the need to contract with an identity tensor. The review feedback identifies critical bugs where an empty tensor (Ndiag == 0) causes unsigned underflow and subsequent out-of-bounds memory access in the CPU functions. For the GPU implementation, an empty element count (Nelem == 0) can lead to an invalid CUDA configuration error, and the use of redundant cudaDeviceSynchronize() calls blocks host execution and degrades performance.

Comment thread src/backend/linalg_internal_cpu/Trace_internal.cpp Outdated
Comment thread src/backend/linalg_internal_cpu/Trace_internal.cpp Outdated
Comment thread src/backend/linalg_internal_gpu/cuTrace_internal.cu Outdated
Comment thread src/backend/linalg_internal_gpu/cuTrace_internal.cu Outdated
Comment thread benchmarks/linalg/Trace_bm.cpp
Comment thread benchmarks/linalg/Trace_bm.cpp Outdated
@IvanaGyro IvanaGyro force-pushed the claude/835-pairwise-sum branch from e355381 to 4c76053 Compare May 29, 2026 07:53
@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch from f480e9b to 6a2cce9 Compare May 29, 2026 07:53
Collaborator Author

@IvanaGyro IvanaGyro left a comment

(Review written by Claude on behalf of @IvanaGyro.)

I verified the central correctness claim: PhysicalStrides reconstructs exactly the row-major-over-invmapper layout that the container layer already uses to address elements (Tensor_impl::get does vec_map(_shape, _invmapper) then row-major over that order), so summing the diagonal in place via strides[ax1] + strides[ax2] is correct, and it's strictly more correct than the old _trace_2d, which assumed contiguous storage (rawdata[i*Ldim + i]) and would have been wrong for a lazily-permuted input. The GPU 2D tree reduction is valid (block dim 512 is a power of two), and the nd offset precomputation mirrors the CPU path. Nice refactor — it removes the inverted layering and the full-tensor contiguous() copy.

A few non-blocking points, left inline:

  1. PhysicalStrides is duplicated byte-for-byte between the CPU and GPU translation units.
  2. extent = (Ndiag - 1) * diag_stride + 1 underflows if Ndiag == 0.
  3. The GPU tree reduction silently assumes _TNinB_TRACE_ is a power of two.
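
The underflow in point 2 is easy to reproduce in isolation. The sketch below assumes 64-bit unsigned extents (as cytnx_uint64 suggests); `DiagonalExtent` and `UnguardedExtent` are hypothetical helper names, not functions from the PR.

```cpp
#include <cstdint>

// Unguarded form, kept only to show the wrap-around: for n_diag == 0 the
// subtraction yields 2^64 - 1 before the multiply.
std::uint64_t UnguardedExtent(std::uint64_t n_diag, std::uint64_t diag_stride) {
  return (n_diag - 1) * diag_stride + 1;
}

// Guarded form: an empty diagonal spans zero storage elements.
std::uint64_t DiagonalExtent(std::uint64_t n_diag, std::uint64_t diag_stride) {
  if (n_diag == 0) return 0;  // avoid evaluating n_diag - 1 on zero
  return (n_diag - 1) * diag_stride + 1;
}
```

Because unsigned arithmetic wraps modulo 2^64, the unguarded result for an empty diagonal depends on the stride and is generally a huge extent, which is what turns an empty tensor into an out-of-bounds read.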

Also a cross-PR note: this PR makes StrideView (from #849) live — the CPU trace now reduces a PairwiseSum over a StrideView temporary. The current call sites are safe because the StrideView temporary outlives the whole PairwiseSum(...) full-expression, but it means the dangling-pointer concern @gemini-code-assist raised on #849's StrideView::Iterator now has a real consumer; worth fixing the iterator (store the base iterator, not a pointer to the view) before more call sites appear in #834 follow-ups.

GPU build + test results (openblas-cuda / mkl-cuda) will follow in a separate comment.

Comment thread src/backend/linalg_internal_cpu/Trace_internal.cpp Outdated
Comment thread src/backend/linalg_internal_cpu/Trace_internal.cpp Outdated
Comment thread src/backend/linalg_internal_gpu/cuTrace_internal.cu Outdated
@IvanaGyro
Collaborator Author

(Comment written by Claude on behalf of @IvanaGyro.)

Test-coverage assessment. The correctness-relevant paths this PR adds aren't all exercised by committed tests:

  1. Non-contiguous / permuted input — the headline behavior change — is untested. The whole point of PhysicalStrides/invmapper is to trace a lazily-permuted tensor in place, where the old _trace_2d assumed contiguous storage (rawdata[i*Ldim + i]). But DenseUniTensorTest.Trace and gpu_Trace both trace a single loaded, contiguous dense4trtensor. Please add a test that permutes a tensor (so is_contiguous() == false) and traces non-adjacent axes, comparing against the contiguous reference — i.e. T.permute(p).Trace(i,j) vs T.permute(p).contiguous().Trace(i,j). Today only the benchmark touches non-adjacent axes.

  2. Commit the brute-force validation. The description says the index math was validated "against a brute-force trace across ranks 2–5 and swapped axis order" — that is the test this PR needs, but only benchmarks/linalg/Trace_bm.cpp was added, no test file. Please commit that brute-force comparison as a gtest, parameterized over rank and axis order, for both CPU and GPU.

  3. 2D path is not covered. The fixtures trace a rank-4 tensor over (0,3) → the _trace_nd / cuTrace_nd_kernel path. Add a rank-2 trace so _trace_2d and the GPU cuTrace_2d_kernel shared-memory reduction are actually exercised.

  4. Dtype coverage on GPU. Add at least one complex-dtype GPU trace (exercises the reinterpret_cast to cuda::std::complex) and an integer dtype, instead of only the loaded fixture's dtype.

  5. Edge case Ndiag == 0 (zero-length traced axis) — ties to the extent underflow inline comment; add a test asserting the intended behavior (throw, or result 0).

Suggestions rather than blockers — except (1) and (2): the permuted-input path is the core change of this PR and currently has no automated coverage.
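
As a rough illustration of the comparison asked for in (1) and (2), the following sketch traces a strided view in place and checks it against a materialized contiguous copy. It uses plain std::vector storage rather than cytnx::Tensor, and every name in it (`Index`, `RowMajorStrides`, `ForEachOffset`, `TraceStrided`, `MakeContiguous`) is hypothetical; a real gtest would call Tensor::permute, contiguous, and Trace instead.

```cpp
#include <cstddef>
#include <vector>

using Index = std::vector<std::size_t>;

// Row-major strides for a contiguous tensor of the given shape.
Index RowMajorStrides(const Index& shape) {
  Index strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;) strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Visit every multi-index of `shape` in row-major order and call f(offset),
// where offset is the storage offset implied by `strides`.
template <class F>
void ForEachOffset(const Index& shape, const Index& strides, F f) {
  std::size_t total = 1;
  for (std::size_t s : shape) total *= s;
  for (std::size_t flat = 0; flat < total; ++flat) {
    std::size_t rem = flat, offset = 0;
    for (std::size_t ax = shape.size(); ax-- > 0;) {
      offset += (rem % shape[ax]) * strides[ax];
      rem /= shape[ax];
    }
    f(offset);
  }
}

// Trace over (ax1, ax2) of a strided view; the output is row-major over the
// remaining axes. Assumes shape[ax1] == shape[ax2].
std::vector<double> TraceStrided(const std::vector<double>& data, const Index& shape,
                                 const Index& strides, std::size_t ax1, std::size_t ax2) {
  Index rem_shape, rem_strides;
  for (std::size_t ax = 0; ax < shape.size(); ++ax) {
    if (ax == ax1 || ax == ax2) continue;
    rem_shape.push_back(shape[ax]);
    rem_strides.push_back(strides[ax]);
  }
  const std::size_t n_diag = shape[ax1];
  const std::size_t diag_stride = strides[ax1] + strides[ax2];
  std::vector<double> out;
  ForEachOffset(rem_shape, rem_strides, [&](std::size_t base) {
    double acc = 0.0;
    for (std::size_t d = 0; d < n_diag; ++d) acc += data[base + d * diag_stride];
    out.push_back(acc);
  });
  return out;
}

// Materialize a strided view into fresh row-major storage (a stand-in for
// contiguous()).
std::vector<double> MakeContiguous(const std::vector<double>& data, const Index& shape,
                                   const Index& strides) {
  std::vector<double> out;
  ForEachOffset(shape, strides, [&](std::size_t off) { out.push_back(data[off]); });
  return out;
}
```

A test built on this shape would permute the strides of a base tensor, trace two non-adjacent axes of the view, and require equality with the trace of the contiguous copy taken with row-major strides.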

@IvanaGyro
Collaborator Author

(Comment written by Claude on behalf of @IvanaGyro.)

GPU build + test results

Built and tested on two GPUs (GTX 1660 SUPER, RTX 2080 Ti), CUDA 12, conda toolchain.

Builds — both succeeded (rc=0):

  • openblas-cuda — library + test_main + gpu_test_main built clean.
  • mkl-cuda — library + test_main + gpu_test_main built clean.

Tests (release, the authoritative result for this PR):

  • openblas-cuda: 100% — 1678/1678 passed, 0 failed.
  • mkl-cuda: 100% — 1678/1678 passed, 0 failed.
  • The GPU trace path this PR rewrites is exercised and green in both: DenseUniTensorTest.gpu_Trace and BlockUniTensorTest.gpu_Trace passed.

Canonical ASan run (ctest --preset cpu-and-cudadebug-openblas-cuda, USE_DEBUG=ON, AddressSanitizer):

  • 99% — 11 failed / 1678. None are related to this change. The failures are all pre-existing under the debug/ASan build, in unrelated areas:
    Rsvd_truncate.*exp_svals*, Gesvd_truncate.*exp_svals*, Svd_truncate.*exp_svals*, Rsvd.*exp_svals*, SearchTreeTest.BasicSearchOrder2, BlockFermionicUniTensorTest.SaveLoad.
    They abort without a gtest [ FAILED ] marker (classic ASan-abort / debug-assert), and gpu_Trace itself passed under ASan too. I did not run the base branch to formally confirm they pre-date this PR, but the change only touches Trace/Sum and no Trace/Sum test fails.

Conclusion: this PR is green on GPU for both BLAS backends; the only failures are unrelated ASan/debug pre-existing issues.


Separately, while running the suite I noticed the GPU tests are dominated by a handful of extremely slow element-wise *AllShapes all_types/mixed_types cases (the worst ~386 s in release). I filed #851 to track speeding them up — not related to this PR.

@IvanaGyro IvanaGyro force-pushed the claude/835-pairwise-sum branch from 4c76053 to 38fdf4f Compare May 29, 2026 23:12
@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch from 6a2cce9 to 95cb3b0 Compare May 29, 2026 23:13
Comment thread src/backend/linalg_internal_gpu/cuTrace_internal.cu Outdated
Collaborator Author

@IvanaGyro IvanaGyro left a comment

(Review written by Claude on behalf of @IvanaGyro.)

This round resolves all four of my earlier code comments: PhysicalStrides is hoisted into a public Tensor::strides() (no more CPU/GPU duplication, and the impl matches the invmapper layout Tensor::at uses), the Ndiag == 0 underflow is guarded on both backends, the _TNinB_TRACE_ power-of-two invariant has a static_assert, and the reductions now use the | stride(...) adaptor. Nice. Remaining items are about test coverage — inline.

Comment thread src/Tensor.cpp
Comment thread src/backend/linalg_internal_cpu/Trace_internal.cpp Outdated
@IvanaGyro
Collaborator Author

(Comment written by Claude on behalf of @IvanaGyro.)

GPU build + test results (updated commit 95cb3b0)

Two GPUs (GTX 1660 SUPER, RTX 2080 Ti), CUDA 12, conda toolchain. Both presets configured with -DRUN_TESTS=ON and built lib + test_main + gpu_test_main.

Builds: both succeeded (rc=0).

Tests (release, ctest):

  • openblas-cuda: 100% — 1683/1683 passed, 0 failed.
  • mkl-cuda: 100% — 1683/1683 passed, 0 failed.
  • The GPU trace path is green on both: DenseUniTensorTest.gpu_Trace and BlockUniTensorTest.gpu_Trace passed.

Notes:

Conclusion: green on GPU for both BLAS backends.

Comment thread benchmarks/linalg/Trace_bm.cpp Outdated
template <class T>
void _trace_nd(Tensor &out, const Tensor &Tn, const cytnx_uint64 &Ndiag,
const cytnx_uint64 &Nelem, const std::vector<cytnx_uint64> &accu,
void _trace_nd(Tensor &out, const Tensor &Tn, cytnx_uint64 Ndiag, cytnx_uint64 Nelem,
Collaborator Author

Return out directly instead of taking it as a parameter. Ndiag, Nelem, accu, remain_rank_id, and shape can all be derived from Tn, ax1, and ax2, so remove them. Also remove Ndiag in the 2D case.
You may want to rewrite the algorithm.

Also consider whether BLAS would be faster when the number of remaining dimensions is large.

Collaborator Author

Done in 5e38cb9. Each Trace_internal_* / cuTrace_internal_* dispatcher now takes only (Tn, ax1, ax2) and returns the result Tensor. A small TraceParams helper (backend/linalg_internal_cpu/trace_dispatch.hpp) derives Ndiag, Nelem, accu, remain_rank_id, the reduced shape, and the is_2d flag once per call from shape+axes, so the dispatchers and the _trace_2d / _trace_nd (and _trace_*_gpu) kernels no longer plumb those individually. The 2D kernels in particular no longer carry a separate Ndiag argument — they read it from the TraceParams. Trace.cpp shrinks to just validation + the dtype-keyed dispatch call.

On the BLAS question: the new benchmark in this PR pits the strided in-place reduction against an actual BLAS matrix-vector (Matmul against vec(I_n)) for the 3D case, and against a BLAS dot product (Vectordot) for the 2D case (and a reshape trick alongside it). At the tested shapes on openblas-cpu the strided path beats both BLAS variants by orders of magnitude — the gather/GEMM/dot setup dominates over the actual reduction. The BLAS routes would only catch up at much larger n; given that, the strided default still seems right. Happy to leave a // TODO and revisit if a profile elsewhere shows the BLAS path winning for a realistic shape.
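
For the 2D case, the difference in work between the two routes can be sketched as follows. This is an assumption about the general shape of the comparison, not the benchmark code: `TraceStrided2d` and `TraceViaDot` are hypothetical names, the matrix is assumed contiguous row-major, and std::inner_product stands in for a BLAS dot call.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Strided route: n additions over the diagonal, no temporaries.
double TraceStrided2d(const std::vector<double>& a, std::size_t n) {
  double acc = 0.0;
  for (std::size_t i = 0; i < n; ++i) acc += a[i * (n + 1)];
  return acc;
}

// Dot-product route: materializes vec(I_n) and performs n * n multiply-adds.
double TraceViaDot(const std::vector<double>& a, std::size_t n) {
  std::vector<double> eye(n * n, 0.0);
  for (std::size_t i = 0; i < n; ++i) eye[i * (n + 1)] = 1.0;
  return std::inner_product(a.begin(), a.end(), eye.begin(), 0.0);
}
```

Both return the same value, but the dot route touches every element and allocates an n * n temporary, which is consistent with setup cost dominating at small n.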


Generated by Claude Code

Collaborator Author

You don't need a dispatcher. You can calculate all the info you need from the shape and strides of the input tensor.

Collaborator Author

Done in 5a843df (CPU) / ba1bb34 (GPU). trace_dispatch.hpp is gone; each Trace_internal_* / cuTrace_internal_* calls a templated TraceImpl<T> (resp. TraceImplGpu<T>) in an unnamed namespace that derives shape / Ndiag / Nelem / accu / remain_rank_id inline from Tn.shape() and Tn.strides() and returns the output Tensor — no separate TraceParams struct or dispatcher header.


Generated by Claude Code

Comment thread src/backend/linalg_internal_cpu/Trace_internal.cpp Outdated
Comment thread benchmarks/linalg/Trace_bm.cpp Outdated
Comment thread benchmarks/linalg/Trace_bm.cpp Outdated
@IvanaGyro IvanaGyro force-pushed the claude/835-pairwise-sum branch from 38fdf4f to 3a5a680 Compare May 30, 2026 09:57
@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch from 95cb3b0 to fd6fd1f Compare May 30, 2026 09:57
cytnx_uint64 Ldim = Tn.shape()[1];
for (cytnx_uint64 i = 0; i < Ndiag; i++) a += rawdata[i * Ldim + i];
out.storage().at<T>(0) = a;
static void _trace_2d(Tensor &out, const Tensor &Tn, const TraceParams &p) {
Collaborator Author

An unnamed namespace is preferred over static; see the Google C++ Style Guide.

Since we are refactoring the whole signature, rename _trace_2d and _trace_nd to follow the style guide.

Collaborator Author

Done in 5a843df. static is gone; the helper now lives in an unnamed namespace inside cytnx::linalg_internal, and the per-dtype _trace_2d / _trace_nd split was absorbed into a single TraceImpl<T> (PascalCase, Google-style function name).


Generated by Claude Code

} else {
_trace_nd_gpu<cytnx_complex128>(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1,
ax2);
__global__ void cuTrace_2d_kernel(T *out, const T *in, cytnx_uint64 Ndiag,
Collaborator Author

Use a name that follows the style guide.

Collaborator Author

Done in ba1bb34: cuTrace_2d_kernel → Trace2dKernel, also moved into an unnamed namespace alongside the rest of the host helpers.


Generated by Claude Code

_trace_nd_gpu<cytnx_float>(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2);
}
template <class T>
static void _trace_2d_gpu(Tensor &out, const Tensor &Tn, const TraceParams &p) {
Collaborator Author

Use a name that follows the style guide.

Collaborator Author

Done in ba1bb34: cuTrace_nd_kernel → TraceNdKernel, and the host launchers were folded in (_trace_2d_gpu / _trace_nd_gpu collapsed into TraceImplGpu<T>).


Generated by Claude Code

Comment on lines +120 to +121
TraceParams p = DeriveTraceParams(Tn, ax1, ax2);
Tensor out = MakeTraceOutput(Tn, p);
Collaborator Author

You don't need these two.

Collaborator Author

Removed in ba1bb34. The dispatcher's DeriveTraceParams + MakeTraceOutput calls are gone; TraceImplGpu<T> derives the shape / Ndiag / Nelem inline and allocates the output buffer itself. No separate TraceDispatch wrapper in the GPU TU either.


Generated by Claude Code

Comment thread tests/linalg_test/stride_view_test.cpp Outdated
Comment on lines +11 to +13
using cytnx::linalg_internal::PairwiseSum;
using cytnx::linalg_internal::stride;
using cytnx::linalg_internal::stride_view;
Collaborator Author

Are these using-declarations allowed by the Google C++ Style Guide?

Collaborator Author

Moved inside the unnamed namespace in 5a843df. Google's style guide allows using-declarations (not using-directives) in .cc files, and explicitly inside an unnamed namespace is the canonical placement — that's the form used in the standard now.


Generated by Claude Code

Comment on lines +13 to +18
using cytnx::cytnx_int32;
using cytnx::cytnx_int64;
using cytnx::cytnx_uint64;
using cytnx::Device;
using cytnx::Tensor;
using cytnx::Type;
Collaborator Author

Are these using-declarations allowed by the Google C++ Style Guide?

Collaborator Author

These were already inside the unnamed namespace (namespace { using cytnx::Tensor; ... }), which is the placement Google's style guide explicitly allows for using-declarations in .cc files (using-directives are the forbidden form). Left as-is in 5a843df — happy to switch to fully qualified names if you'd still prefer that.


Generated by Claude Code

@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch 3 times, most recently from 47869fa to b295b04 Compare May 31, 2026 01:53
@pcchen pcchen added this to the v1.1.0 milestone May 31, 2026
@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch from b295b04 to cce6d7d Compare May 31, 2026 03:05
@IvanaGyro IvanaGyro changed the title refactor(trace): compute Tensor trace in the container layer (CPU + GPU) refactor(trace): implement the GPU Tensor trace as CUDA kernels May 31, 2026
@IvanaGyro IvanaGyro changed the base branch from claude/835-pairwise-sum to claude/trace-cpu May 31, 2026 03:06
Comment on lines +75 to +76
const cytnx_uint64 ax1 = std::min(a1, a2);
const cytnx_uint64 ax2 = std::max(a1, a2);
Collaborator Author

You don't need this.

if (i < Nelem) {
cytnx_uint64 base = offsets[i];
T acc = T(0);
for (cytnx_uint64 d = 0; d < Ndiag; d++) acc += in[base + d * diag_stride];
Collaborator Author

Why not call Trace2dKernel?

tmp %= accu[x];
}
offsets[i] = base;
}
Collaborator Author

Lines 109-122: the offset can be calculated in the kernel. Please precalculate strides[remain_rank_id[x]] so that you just need some_new_strides[x].
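
One way to read this suggestion: the host gathers the surviving axes' extents and input strides once, and each thread decodes its own flat output index into a base offset. The sketch below shows that decode as a plain host function. `BaseOffset`, `rem_shape`, and `rem_strides` are hypothetical names, and row-major ordering of the output index is an assumption.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-thread decode: maps a flat output index to the input base
// offset of that output element's diagonal.
//   rem_shape[x]   = extent of the x-th surviving axis
//   rem_strides[x] = input stride of that axis, gathered once on the host
std::uint64_t BaseOffset(std::uint64_t flat, const std::vector<std::uint64_t>& rem_shape,
                         const std::vector<std::uint64_t>& rem_strides) {
  std::uint64_t base = 0;
  // Peel off the fastest-varying (last) axis first.
  for (std::size_t x = rem_shape.size(); x-- > 0;) {
    base += (flat % rem_shape[x]) * rem_strides[x];
    flat /= rem_shape[x];
  }
  return base;
}
```

For a contiguous input of shape {2, 4, 3, 4} traced over axes 1 and 3, the surviving axes are 0 and 2, so rem_shape is {2, 3} and rem_strides is {48, 4}. Only these two short arrays need to reach the device, rather than one offset per output element.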

Trace2dKernel<CudaT><<<1, kTraceThreadsPerBlock>>>(
reinterpret_cast<CudaT *>(out.storage().data()),
reinterpret_cast<const CudaT *>(Tn.storage().data()), Ndiag, diag_stride);
return out;
Collaborator Author

Don't compose the Tensor until the return. Same for the ND path.

cudaMemcpyHostToDevice));

cytnx_uint64 NBlocks = Nelem / kTraceThreadsPerBlock;
if (Nelem % kTraceThreadsPerBlock) NBlocks += 1;
Collaborator Author

Why is += 1 needed? Please give NBlocks and Nelem better names that follow the style guide.
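
For context on the question: integer division truncates, so a final, partially filled block needs one extra block, and the kernel's `if (i < Nelem)` check masks the threads past the end. A common single-expression form is sketched below; `kThreadsPerBlock`, `BlocksFor`, and `num_outputs` are hypothetical names, not the ones in the PR.

```cpp
#include <cstdint>

// Hypothetical stand-in for kTraceThreadsPerBlock.
constexpr std::uint64_t kThreadsPerBlock = 512;

// Smallest block count such that blocks * kThreadsPerBlock >= num_outputs.
// Assumes num_outputs is far below 2^64, so the addition cannot wrap.
constexpr std::uint64_t BlocksFor(std::uint64_t num_outputs) {
  return (num_outputs + kThreadsPerBlock - 1) / kThreadsPerBlock;
}
```

This is the same rounding-up division as the divide-then-increment pair in the diff, written without the branch.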

Collaborator Author

@IvanaGyro IvanaGyro left a comment

(Review written by Claude on behalf of @IvanaGyro.)

Replacing the eye+Contract GPU trace with direct kernels is a solid win and mirrors the CPU TraceImpl offset math 1:1, so the two stay easy to keep in sync. The cuda::std::complex mapping for the complex dtypes is the right way to keep the kernels type-generic, and the power-of-two static_assert on the block size correctly guards the tree reduction. I built and ran this on both openblas-cuda and mkl-cuda: DenseUniTensorTest.gpu_Trace and BlockUniTensorTest.gpu_Trace pass on both. Comments inline: an async-launch error-check gap, and a GPU test-coverage gap that I think is the most important follow-up here.

const cytnx_uint64 diag_stride = strides[ax1] + strides[ax2];

if (is_2d) {
Trace2dKernel<CudaT><<<1, kTraceThreadsPerBlock>>>(
Collaborator Author

(Review written by Claude on behalf of @IvanaGyro.)

The 2D path launches the kernel and returns immediately with no error check. Unlike the ND path — where the following checkCudaErrors(cudaFree(d_offsets)) synchronizes and surfaces any launch/exec error — a bad 2D launch (and any async execution error) is silently swallowed until the next unrelated CUDA call, which then reports a confusing error at the wrong site.

Suggest adding checkCudaErrors(cudaGetLastError()); immediately after both kernel launches so a configuration/launch failure is reported at the trace call.

if (i < Nelem) {
cytnx_uint64 base = offsets[i];
T acc = T(0);
for (cytnx_uint64 d = 0; d < Ndiag; d++) acc += in[base + d * diag_stride];
Collaborator Author

(Review written by Claude on behalf of @IvanaGyro.)

These new kernels are currently only exercised by the existing DenseUniTensorTest.gpu_Trace / BlockUniTensorTest.gpu_Trace, which run a contiguous, fixed shape/dtype. The interesting paths this rewrite introduces are untested on GPU: the strided/permuted input (non-trivial diag_stride via Tn.strides()), the ND kernel with non-adjacent traced axes, complex vs integer dtypes through the cuda::std::complex mapping, the 2D vs ND branch split, and the Ndiag==0/Nelem==0 early-out.

#862 added exactly this matrix for CPU as LinalgTraceTest (permuted ranks 2–5, Double/ComplexDouble/Int32, 2D path, swapped axes). Strong suggestion: parameterize that suite over {Device.cpu, Device.cuda} (guarded by UNI_GPU) and compare the GPU result against the contiguous-CPU reference, so the kernels' stride handling and per-dtype accumulation are actually covered. Separately, note the kernels accumulate the diagonal serially (acc += ...) rather than the pairwise sum the CPU path now uses — fine for correctness, but a large-Ndiag accuracy comparison would document the tolerance gap intentionally.
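
The serial-versus-pairwise gap mentioned above can be demonstrated on the host with single-precision floats. The sketch below is illustrative only: `SerialSum` mirrors the order of an `acc += ...` loop, and `PairwiseSumSketch` is a plain recursive halving, not the project's PairwiseSum.

```cpp
#include <cstddef>
#include <vector>

// Serial accumulation, in the same order as an `acc += ...` loop.
float SerialSum(const std::vector<float>& v) {
  float acc = 0.0f;
  for (float x : v) acc += x;
  return acc;
}

// Plain recursive pairwise sum over [lo, hi): split in half, sum each half,
// then add the two partial sums.
float PairwiseSumSketch(const std::vector<float>& v, std::size_t lo, std::size_t hi) {
  if (hi == lo) return 0.0f;
  if (hi - lo == 1) return v[lo];
  const std::size_t mid = lo + (hi - lo) / 2;
  return PairwiseSumSketch(v, lo, mid) + PairwiseSumSketch(v, mid, hi);
}
```

A deliberately adversarial input is one entry of 2^24 followed by many entries of 1.0f. Under IEEE-754 binary32 with round-to-nearest-even, the serial sum stalls at 2^24 because adding 1 no longer changes it, while the pairwise sum first combines the small entries among themselves and stays within a few units of the exact result. A GPU accuracy test could compare both orders against a double-precision reference in the same way.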

@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch from cce6d7d to e71c3fc Compare May 31, 2026 08:40
@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch 2 times, most recently from 12b9474 to 9bb08b3 Compare June 2, 2026 02:33
Implements the GPU Tensor trace as native CUDA kernels that read storage
directly via Tensor::strides(), replacing the previous Contract-based path that
built an identity UniTensor per call.

* Trace2dKernel<T>: single-block shared-memory tree reduction for the rank-2
  path. The block size is a constexpr int kTraceThreadsPerBlock = 512; a
  static_assert pins the power-of-two invariant the tree reduction relies on.
* TraceNdKernel<T>: one thread per output element for the rank >= 3 path. Each
  thread decodes its own flat index into the remaining-rank multi-index and
  accumulates the input base offset from the surviving axes' input strides, then
  sums the Ndiag diagonal entries. Decoding on the device means the host ships
  only the two out_rank-sized layout arrays (shape + per-axis input stride)
  rather than an Nelem-sized offset table.
* TraceImplGpu<T> derives the same Ndiag / Nelem / remaining-axis layout inline
  from Tn.shape() + Tn.strides() as the CPU TraceImpl<T>.
* Type-generic over T using T(0) / +=; complex storage is reinterpret_cast to
  cuda::std::complex<{float,double}>, which provides device operator+ and
  zero-construction.
* Ndiag == 0 / Nelem == 0 guards skip the kernel launch (an <<<0, ...>>> launch
  would return cudaErrorInvalidConfiguration).
* checkCudaErrors(cudaGetLastError()) after each launch surfaces a launch or
  configuration failure at the trace call instead of at the next unrelated CUDA
  call. The trailing cudaFree synchronizes the device, so the ND kernel's reads
  of the layout arrays are complete before the buffers are released.

Co-authored-by: Claude <noreply@anthropic.com>
@IvanaGyro IvanaGyro force-pushed the claude/gallant-franklin-4g4rY branch from 9bb08b3 to 8f14590 Compare June 2, 2026 03:00
Development

Successfully merging this pull request may close these issues.

Refactor Tensor trace: low-level backend depends on UniTensor (API encapsulation violation)