diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index df7b0bedc..e9143302b 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable( linalg/Directsum_bm.cpp linalg/Svd_bm.cpp linalg/Svd_truncate_bm.cpp + linalg/Trace_bm.cpp linalg/Vectordot_bm.cpp linalg/Lanczos_bm.cpp linalg/linalg_basic_bm.cpp diff --git a/benchmarks/linalg/Trace_bm.cpp b/benchmarks/linalg/Trace_bm.cpp new file mode 100644 index 000000000..177307c2c --- /dev/null +++ b/benchmarks/linalg/Trace_bm.cpp @@ -0,0 +1,301 @@ +#include + +#include +#include +#include +#include +#include + +#include "cytnx.hpp" + +// Benchmarks for linalg::Trace. +// +// Two shapes are exercised: +// +// * 3D path: a contiguous {n, middle, n} tensor traced over axes 0 and 2. +// The two traced axes are not adjacent in storage, so the diagonal stride +// (strides[0] + strides[2] = middle * n + 1) is large; each iteration is a +// genuine higher-rank trace rather than a simple 2D matrix trace. +// * 2D path: a contiguous {n, n} tensor traced over axes 0 and 1, exercising +// the 2D (is_2d) branch of the CPU TraceImpl / cuTrace_2d_kernel. +// +// For each shape, the in-place strided reduction is compared against a +// "collect the traced elements contiguously and reduce them" baseline that +// avoids reading the diagonal with a large stride: +// +// * 3D matvec baseline: tr(A)[m] = <vec(I_n), vec(A[:, m, :])>, so stacking +// the middle outputs becomes a {middle, n*n} @ {n*n, 1} GEMM call against +// vec(I_n) -- a BLAS matrix-vector multiplication. +// * 2D vecdot baseline: tr(A) = <vec(I_n), vec(A)> -- a BLAS vector-vector +// dot product on the n*n-element flattenings of I and A. +// * 2D reshape trick: drop the last diagonal entry A[n-1, n-1], view the +// remaining n*n - 1 entries as {n-1, n+1}, permute + contiguous so the +// first column (which holds diag[0..n-2]) becomes a contiguous row, then +// reduce that row over the raw buffer and add the saved last entry. 
+// +// Every variant runs once before the timing loop and is checked against +// linalg::Trace on the same input, so a wrong baseline would fail loudly +// instead of producing fast-but-meaningless numbers. + +namespace BMTest_Trace { + + // Maps a benchmarked dtype enum to its C++ storage type, so a baseline can + // read the raw buffer directly. Only the dtypes the 2D reshape trick is + // registered for need an entry. + template + struct BmCType; + template <> + struct BmCType { + using type = cytnx::cytnx_double; + }; + template <> + struct BmCType { + using type = cytnx::cytnx_complex128; + }; + + // Aborts the benchmark if `candidate` differs from `reference` by more than + // a small absolute tolerance. Both tensors must have the same dtype, and + // (after broadcasting via Tensor's arithmetic) be element-wise comparable. + static void VerifyAgainstTrace(const cytnx::Tensor& reference, const cytnx::Tensor& candidate, + const char* variant_name) { + auto diff = (candidate - reference).Abs(); + const double max_err = diff.Max().item(); + if (!(max_err < 1e-6)) { + std::cerr << "[Trace benchmark] variant \"" << variant_name + << "\" disagrees with linalg::Trace; max abs error = " << max_err << "\n"; + std::abort(); + } + } + + template + static void BM_Trace_Strided_3D_Template(benchmark::State& state) { + const auto n = state.range(0); + const auto middle = state.range(1); + const auto tensor = cytnx::random::random_tensor({n, middle, n}, -1.0, 1.0, device, 0, dtype); + const auto reference = cytnx::linalg::Trace(tensor, 0, 2); + VerifyAgainstTrace(reference, cytnx::linalg::Trace(tensor, 0, 2), "Strided_3D"); + for (auto _ : state) { + auto result = cytnx::linalg::Trace(tensor, 0, 2); + benchmark::DoNotOptimize(result); + } + state.SetItemsProcessed(static_cast(state.iterations()) * n * n * middle); + } + + // ND matvec baseline: tr(A)[m] = <vec(I_n), vec(A[:, m, :])>. Stacking the + // middle outputs becomes a {middle, n*n} @ {n*n, 1} GEMM call against vec(I_n). 
+ template + static void BM_Trace_Matvec_3D_Template(benchmark::State& state) { + const auto n = state.range(0); + const auto middle = state.range(1); + const auto tensor = cytnx::random::random_tensor({n, middle, n}, -1.0, 1.0, device, 0, dtype); + const auto vec_I = cytnx::eye(n, dtype, device).reshape({n * n, 1}); + const auto reference = cytnx::linalg::Trace(tensor, 0, 2); + auto compute = [&]() { + auto packed = tensor.permute({1, 0, 2}).contiguous().reshape({middle, n * n}); + return cytnx::linalg::Matmul(packed, vec_I).reshape({middle}); + }; + VerifyAgainstTrace(reference, compute(), "Matvec_3D"); + for (auto _ : state) { + auto result = compute(); + benchmark::DoNotOptimize(result); + } + state.SetItemsProcessed(static_cast(state.iterations()) * n * n * middle); + } + + template + static void BM_Trace_Strided_2D_Template(benchmark::State& state) { + const auto n = state.range(0); + const auto tensor = cytnx::random::random_tensor({n, n}, -1.0, 1.0, device, 0, dtype); + const auto reference = cytnx::linalg::Trace(tensor, 0, 1); + VerifyAgainstTrace(reference, cytnx::linalg::Trace(tensor, 0, 1), "Strided_2D"); + for (auto _ : state) { + auto result = cytnx::linalg::Trace(tensor, 0, 1); + benchmark::DoNotOptimize(result); + } + state.SetItemsProcessed(static_cast(state.iterations()) * n * n); + } + + // 2D vector-dot baseline: tr(A) = <vec(I_n), vec(A)> as a BLAS dot product + // on the n*n-element flattenings. 
+ template + static void BM_Trace_Vecdot_2D_Template(benchmark::State& state) { + const auto n = state.range(0); + const auto tensor = cytnx::random::random_tensor({n, n}, -1.0, 1.0, device, 0, dtype); + const auto vec_I = cytnx::eye(n, dtype, device).reshape({n * n}); + const auto reference = cytnx::linalg::Trace(tensor, 0, 1); + auto compute = [&]() { + auto vec_A = tensor.reshape({n * n}); + return cytnx::linalg::Vectordot(vec_I, vec_A).reshape({1}); + }; + VerifyAgainstTrace(reference, compute(), "Vecdot_2D"); + for (auto _ : state) { + auto result = compute(); + benchmark::DoNotOptimize(result); + } + state.SetItemsProcessed(static_cast(state.iterations()) * n * n); + } + + // 2D reshape trick: drop A[n-1, n-1], view the remaining n*n - 1 elements as + // {n-1, n+1}, permute + contiguous so the first column (which holds + // diag[0..n-2]) becomes a contiguous row, reduce that row, and add the saved + // last element. + // + // To keep the bulk data movement honest, the {n-1, n+1} view reuses the + // source storage directly: shrinking its size to (n-1)*(n+1) is a no-realloc + // metadata change (Storage::resize only reallocates when growing past the + // capacity), so the permute -> contiguous gather is the only copy the trick + // performs. The size is restored to n*n afterwards (and the dropped corner + // written back, since a grow-resize zero-fills any tail it has to reallocate) + // so the source tensor stays intact across iterations. + // + // The row reduction is device-specific: on CPU the contiguous row is summed + // over the raw buffer (no linalg::Sum / Accessor allocations on top of the + // gather); on GPU the raw buffer lives in device memory and cannot be read + // host-side, so the row is wrapped as a {n-1} tensor over the resized gather + // storage and reduced with linalg::Sum, and the corner is taken as a {1} + // tensor so the final add stays on the device. 
+ template + static void BM_Trace_Reshape_2D_Template(benchmark::State& state) { + using T = typename BmCType::type; + const auto n = state.range(0); + const auto tensor = cytnx::random::random_tensor({n, n}, -1.0, 1.0, device, 0, dtype); + const auto reference = cytnx::linalg::Trace(tensor, 0, 1); + auto compute = [&]() { + auto& storage = tensor.storage(); + if constexpr (device == cytnx::Device_class::cpu) { + const T last = tensor.at( + {static_cast(n - 1), static_cast(n - 1)}); + storage.resize((n - 1) * (n + 1)); + auto view = cytnx::Tensor::from_storage(storage); + view.reshape_({n - 1, n + 1}); + auto packed = view.permute({1, 0}).contiguous(); + storage.resize(n * n); // restore the source storage size + storage.at((n * n) - 1) = last; // and recover the dropped corner + + const T* row = packed.storage().data(); + T sum = T(0); + for (cytnx::cytnx_int64 k = 0; k < n - 1; ++k) sum += row[k]; + sum += last; + + auto out = cytnx::Tensor({static_cast(1)}, dtype, device); + out.storage().at(0) = sum; + return out; + } else { + // Read the corner as a {1} device tensor before the storage shrinks. 
+ auto last = tensor.reshape({n * n}).get({cytnx::Accessor((n * n) - 1)}); + storage.resize((n - 1) * (n + 1)); + auto view = cytnx::Tensor::from_storage(storage); + view.reshape_({n - 1, n + 1}); + auto packed = view.permute({1, 0}).contiguous(); + storage.resize(n * n); // restore the source storage size + + auto row_storage = packed.storage(); + row_storage.resize(n - 1); // first contiguous row = diag[0..n-2] + auto row = cytnx::Tensor::from_storage(row_storage); + return (cytnx::linalg::Sum(row) + last).reshape({1}); + } + }; + VerifyAgainstTrace(reference, compute(), "Reshape_2D"); + for (auto _ : state) { + auto result = compute(); + benchmark::DoNotOptimize(result); + } + state.SetItemsProcessed(static_cast(state.iterations()) * n * n); + } + +#define REGISTER_TRACE_3D_BENCHMARK(TypeName, TypeEnum) \ + BENCHMARK_TEMPLATE(BM_Trace_Strided_3D_Template, TypeEnum) \ + ->Name("BM_Trace_Strided_3D_" #TypeName) \ + ->Args({64, 64}) \ + ->Args({256, 64}) \ + ->Args({1024, 16}) \ + ->Args({2048, 16}) \ + ->Args({4096, 8}) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(BM_Trace_Matvec_3D_Template, TypeEnum) \ + ->Name("BM_Trace_Matvec_3D_" #TypeName) \ + ->Args({64, 64}) \ + ->Args({256, 64}) \ + ->Args({1024, 16}) \ + ->Args({2048, 16}) \ + ->Args({4096, 8}) \ + ->Unit(benchmark::kMicrosecond); + +#define REGISTER_TRACE_2D_BENCHMARK(TypeName, TypeEnum) \ + BENCHMARK_TEMPLATE(BM_Trace_Strided_2D_Template, TypeEnum) \ + ->Name("BM_Trace_Strided_2D_" #TypeName) \ + ->Args({64}) \ + ->Args({256}) \ + ->Args({1024}) \ + ->Args({4096}) \ + ->Args({8192}) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(BM_Trace_Vecdot_2D_Template, TypeEnum) \ + ->Name("BM_Trace_Vecdot_2D_" #TypeName) \ + ->Args({64}) \ + ->Args({256}) \ + ->Args({1024}) \ + ->Args({4096}) \ + ->Args({8192}) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(BM_Trace_Reshape_2D_Template, TypeEnum) \ + ->Name("BM_Trace_Reshape_2D_" #TypeName) \ + ->Args({64}) \ + 
->Args({256}) \ + ->Args({1024}) \ + ->Args({4096}) \ + ->Args({8192}) \ + ->Unit(benchmark::kMicrosecond); + + REGISTER_TRACE_3D_BENCHMARK(Double, cytnx::Type.Double) + REGISTER_TRACE_3D_BENCHMARK(ComplexDouble, cytnx::Type.ComplexDouble) + REGISTER_TRACE_2D_BENCHMARK(Double, cytnx::Type.Double) + REGISTER_TRACE_2D_BENCHMARK(ComplexDouble, cytnx::Type.ComplexDouble) + +#ifdef UNI_GPU + #define REGISTER_GPU_TRACE_3D_BENCHMARK(TypeName, TypeEnum) \ + BENCHMARK_TEMPLATE(BM_Trace_Strided_3D_Template, TypeEnum, cytnx::Device.cuda) \ + ->Name("BM_gpu_Trace_Strided_3D_" #TypeName) \ + ->Args({256, 64}) \ + ->Args({1024, 16}) \ + ->Args({2048, 16}) \ + ->Args({4096, 8}) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(BM_Trace_Matvec_3D_Template, TypeEnum, cytnx::Device.cuda) \ + ->Name("BM_gpu_Trace_Matvec_3D_" #TypeName) \ + ->Args({256, 64}) \ + ->Args({1024, 16}) \ + ->Args({2048, 16}) \ + ->Args({4096, 8}) \ + ->Unit(benchmark::kMicrosecond); + + #define REGISTER_GPU_TRACE_2D_BENCHMARK(TypeName, TypeEnum) \ + BENCHMARK_TEMPLATE(BM_Trace_Strided_2D_Template, TypeEnum, cytnx::Device.cuda) \ + ->Name("BM_gpu_Trace_Strided_2D_" #TypeName) \ + ->Args({256}) \ + ->Args({1024}) \ + ->Args({4096}) \ + ->Args({8192}) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(BM_Trace_Vecdot_2D_Template, TypeEnum, cytnx::Device.cuda) \ + ->Name("BM_gpu_Trace_Vecdot_2D_" #TypeName) \ + ->Args({256}) \ + ->Args({1024}) \ + ->Args({4096}) \ + ->Args({8192}) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(BM_Trace_Reshape_2D_Template, TypeEnum, cytnx::Device.cuda) \ + ->Name("BM_gpu_Trace_Reshape_2D_" #TypeName) \ + ->Args({256}) \ + ->Args({1024}) \ + ->Args({4096}) \ + ->Args({8192}) \ + ->Unit(benchmark::kMicrosecond); + + REGISTER_GPU_TRACE_3D_BENCHMARK(Double, cytnx::Type.Double) + REGISTER_GPU_TRACE_3D_BENCHMARK(ComplexDouble, cytnx::Type.ComplexDouble) + REGISTER_GPU_TRACE_2D_BENCHMARK(Double, cytnx::Type.Double) + 
REGISTER_GPU_TRACE_2D_BENCHMARK(ComplexDouble, cytnx::Type.ComplexDouble) +#endif // UNI_GPU + +} // namespace BMTest_Trace diff --git a/include/Tensor.hpp b/include/Tensor.hpp index 11fd45bf7..3e7a4c334 100644 --- a/include/Tensor.hpp +++ b/include/Tensor.hpp @@ -596,6 +596,17 @@ namespace cytnx { */ const std::vector &shape() const { return this->_impl->shape(); } + /** + @brief the storage strides of the Tensor + @return [std::vector] for each logical axis, the distance in the + underlying storage between consecutive elements along that axis + @details cytnx tensors store a dense permutation of their logical axes, so the + stride of every axis is well defined (this is the layout Tensor::at indexes + through). For a contiguous tensor these are the row-major strides; for a + permuted (non-contiguous) view they reflect the permuted memory order. + */ + std::vector strides() const; + /** @brief the rank of the Tensor @return [cytnx_uint64] the rank of the Tensor diff --git a/src/Tensor.cpp b/src/Tensor.cpp index 7d5d51544..bd595c787 100644 --- a/src/Tensor.cpp +++ b/src/Tensor.cpp @@ -907,6 +907,22 @@ namespace cytnx { return is(this->_impl->storage(), rhs.storage()); } + std::vector Tensor::strides() const { + // The storage is laid out contiguously in memory order; _invmapper[i] gives + // the logical axis sitting at memory position i (innermost last). The stride + // of a logical axis is the product of the memory-order extents inside it. 
+ const std::vector &shape = this->_impl->shape(); + const std::vector &invmapper = this->_impl->invmapper(); + const cytnx_uint64 rank = shape.size(); + std::vector out(rank); + cytnx_uint64 step = 1; + for (cytnx_int64 i = static_cast(rank) - 1; i >= 0; i--) { + out[invmapper[i]] = step; + step *= shape[invmapper[i]]; + } + return out; + } + //=========================== // Tensor am Tproxy Tensor operator+(const Tensor &lhs, const Tensor::Tproxy &rhs) { diff --git a/src/backend/linalg_internal_cpu/Trace_internal.cpp b/src/backend/linalg_internal_cpu/Trace_internal.cpp index 858b8d3bc..409c454a7 100644 --- a/src/backend/linalg_internal_cpu/Trace_internal.cpp +++ b/src/backend/linalg_internal_cpu/Trace_internal.cpp @@ -1,192 +1,125 @@ #include "Trace_internal.hpp" +#include "Tensor.hpp" +#include "backend/Storage.hpp" #include "cytnx_error.hpp" -#include "backend/lapack_wrapper.hpp" -#include "Generator.hpp" -#include "utils/utils.hpp" +#include "backend/linalg_internal_cpu/pairwise_sum.hpp" +#include "backend/linalg_internal_cpu/stride_view.hpp" -#include "UniTensor.hpp" +#include +#include #include namespace cytnx { namespace linalg_internal { + namespace { + + template + Tensor TraceImpl(const Tensor &Tn, cytnx_uint64 a1, cytnx_uint64 a2) { + const cytnx_uint64 ax1 = std::min(a1, a2); + const cytnx_uint64 ax2 = std::max(a1, a2); + const auto &shape_in = Tn.shape(); + const cytnx_uint64 Ndiag = shape_in[ax1]; + + std::vector out_shape; + std::vector remain_rank_id; + for (cytnx_uint64 i = 0; i < shape_in.size(); ++i) { + if (i != ax1 && i != ax2) { + out_shape.push_back(static_cast(shape_in[i])); + remain_rank_id.push_back(i); + } + } + cytnx_uint64 Nelem = 1; + for (auto d : out_shape) Nelem *= static_cast(d); + const bool is_2d = out_shape.empty(); + + // Fill a flat result Storage, then compose the output Tensor from it; the + // 2D trace produces a single element, the ND trace one element per + // remaining-rank multi-index. + Storage out_storage(is_2d ? 
cytnx_uint64{1} : Nelem, Tn.dtype(), Tn.device()); + if (Ndiag == 0 || Nelem == 0) { + out_storage.set_zeros(); + Tensor out = Tensor::from_storage(out_storage); + if (!is_2d) out.reshape_(out_shape); + return out; + } + + const std::vector strides = Tn.strides(); + const cytnx_uint64 diag_stride = strides[ax1] + strides[ax2]; + const cytnx_uint64 extent = (Ndiag - 1) * diag_stride + 1; + const T *data = Tn.storage().data(); + T *out_data = out_storage.data(); + + if (is_2d) { + out_data[0] = PairwiseSum(std::span(data, extent) | stride(diag_stride)); + return Tensor::from_storage(out_storage); + } + + // Input stride for each surviving (output) axis, so the hot loop indexes a + // flat array instead of going through remain_rank_id on every step. + std::vector out_strides(out_shape.size()); + for (cytnx_uint64 x = 0; x < out_shape.size(); ++x) + out_strides[x] = strides[remain_rank_id[x]]; + + // Walk the output elements in row-major order, carrying the input base + // offset on an odometer: each step bumps the last axis index (carrying into + // earlier axes on wrap) and adjusts base by the affected axes' strides. This + // avoids the per-element division and modulo of decoding the flat index, and + // needs no precomputed row-major accumulators. 
+ std::vector index(out_shape.size(), 0); + cytnx_uint64 base = 0; + for (cytnx_uint64 i = 0; i < Nelem; ++i) { + out_data[i] = PairwiseSum(std::span(data + base, extent) | stride(diag_stride)); + for (cytnx_uint64 x = out_shape.size(); x-- > 0;) { + if (++index[x] < static_cast(out_shape[x])) { + base += out_strides[x]; + break; + } + index[x] = 0; + base -= (static_cast(out_shape[x]) - 1) * out_strides[x]; + } + } + Tensor out = Tensor::from_storage(out_storage); + out.reshape_(out_shape); + return out; + } - template - void _trace_2d(Tensor &out, const Tensor &Tn, const cytnx_uint64 &Ndiag) { - T a = 0; - T *rawdata = Tn.storage().data(); - cytnx_uint64 Ldim = Tn.shape()[1]; - for (cytnx_uint64 i = 0; i < Ndiag; i++) a += rawdata[i * Ldim + i]; - out.storage().at(0) = a; - } - - template - void _trace_nd(Tensor &out, const Tensor &Tn, const cytnx_uint64 &Ndiag, - const cytnx_uint64 &Nelem, const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - cytnx::UniTensor I_UT = cytnx::UniTensor::eye(Ndiag, {}, true, Tn.dtype(), Tn.device()); - - UniTensor UTn = UniTensor(Tn, false, 2); - I_UT.relabel_({UTn._impl->_labels[ax1], UTn._impl->_labels[ax2]}); - - out = Contract(I_UT, UTn).get_block_(); - - // std::vector indexer(Tn.shape().size(), 0); - // cytnx_uint64 tmp; - // for (cytnx_uint64 i = 0; i < Nelem; i++) { - // tmp = i; - // // calculate indexer - // for (int x = 0; x < shape.size(); x++) { - // indexer[remain_rank_id[x]] = cytnx_uint64(tmp / accu[x]); - // tmp %= accu[x]; - // } + } // namespace - // for (cytnx_uint64 d = 0; d < Ndiag; d++) { - // indexer[ax1] = indexer[ax2] = d; - // out.storage().at(i) += Tn.at(indexer); - // } - // } + Tensor Trace_internal_cd(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const 
cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, Tn, Ndiag); - } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_cf(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, Tn, Ndiag); - } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_d(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, Tn, Ndiag); - } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_f(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, Tn, Ndiag); - } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_u64(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, 
ax1, ax2); } - - void Trace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, Tn, Ndiag); - } else { - _trace_nd(out, Tn, Ndiag, Nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_i64(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, tn, ndiag); - } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_u32(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, tn, ndiag); - } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_i32(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, tn, ndiag); - } else { - _trace_nd(out, tn, ndiag, nelem, accu, 
remain_rank_id, shape, ax1, ax2); - } - } - - void Trace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, tn, ndiag); - } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_u16(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - - void Trace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { - if (is_2d) { - _trace_2d(out, tn, ndiag); - } else { - _trace_nd(out, tn, ndiag, nelem, accu, remain_rank_id, shape, ax1, ax2); - } + Tensor Trace_internal_i16(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { + return TraceImpl(Tn, ax1, ax2); } - void Trace_internal_b(const bool &is_2d, Tensor &out, const Tensor &tn, - const cytnx_uint64 &ndiag, const cytnx_uint64 &nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2) { + Tensor Trace_internal_b(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2) { cytnx_error_msg(true, "[internal][Trace] bool is not available. 
%s", "\n"); + return Tensor(); } } // namespace linalg_internal diff --git a/src/backend/linalg_internal_cpu/Trace_internal.hpp b/src/backend/linalg_internal_cpu/Trace_internal.hpp index 68c7720c5..2c033c465 100644 --- a/src/backend/linalg_internal_cpu/Trace_internal.hpp +++ b/src/backend/linalg_internal_cpu/Trace_internal.hpp @@ -12,82 +12,17 @@ namespace cytnx { namespace linalg_internal { - void Trace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector 
&shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void Trace_internal_b(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); + Tensor Trace_internal_cd(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_cf(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_d(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_f(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_u64(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_i64(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_u32(const Tensor &Tn, 
cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_i32(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_u16(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_i16(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor Trace_internal_b(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); } // namespace linalg_internal diff --git a/src/backend/linalg_internal_gpu/cuTrace_internal.hpp b/src/backend/linalg_internal_gpu/cuTrace_internal.hpp index 13fc06fd9..cf0ea86db 100644 --- a/src/backend/linalg_internal_gpu/cuTrace_internal.hpp +++ b/src/backend/linalg_internal_gpu/cuTrace_internal.hpp @@ -12,82 +12,17 @@ namespace cytnx { namespace linalg_internal { - void cuTrace_internal_cd(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_cf(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_d(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_f(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_u64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const 
std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_i64(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_u32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_i32(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_u16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_i16(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); - - void cuTrace_internal_b(const bool &is_2d, Tensor &out, const Tensor &Tn, - const cytnx_uint64 &Ndiag, const cytnx_uint64 &Nelem, - const std::vector &accu, - const std::vector &remain_rank_id, - const std::vector &shape, const cytnx_uint64 &ax1, - const cytnx_uint64 &ax2); + Tensor cuTrace_internal_cd(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_cf(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor 
cuTrace_internal_d(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_f(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_u64(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_i64(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_u32(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_i32(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_u16(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_i16(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); + Tensor cuTrace_internal_b(const Tensor &Tn, cytnx_uint64 ax1, cytnx_uint64 ax2); } // namespace linalg_internal diff --git a/src/backend/linalg_internal_interface.hpp b/src/backend/linalg_internal_interface.hpp index 5f5b5b696..73a948bd7 100644 --- a/src/backend/linalg_internal_interface.hpp +++ b/src/backend/linalg_internal_interface.hpp @@ -178,11 +178,7 @@ namespace cytnx { typedef void (*Sumfunc_oii)(boost::intrusive_ptr &, const boost::intrusive_ptr &, const cytnx_uint64 &); - typedef void (*Tracefunc_oii)(const bool &, Tensor &, const Tensor &, const cytnx_uint64 &, - const cytnx_uint64 &, const std::vector &, - const std::vector &, - const std::vector &, const cytnx_uint64 &, - const cytnx_uint64 &); + typedef Tensor (*Tracefunc_oii)(const Tensor &, cytnx_uint64, cytnx_uint64); typedef void (*Tensordotfunc_oii)(Tensor &out, const Tensor &Lin, const Tensor &Rin, const std::vector &idxl, diff --git a/src/linalg/Trace.cpp b/src/linalg/Trace.cpp index 5b481c652..4bd0c9e08 100644 --- a/src/linalg/Trace.cpp +++ b/src/linalg/Trace.cpp @@ -45,73 +45,17 @@ namespace cytnx { "shape(%d) = %ld and shape(%d) = %ld do not match.%s", axisA, Tn.shape()[axisA], axisB, Tn.shape()[axisB], "\n"); - cytnx_uint64 ax1, ax2; - if (axisA < axisB) { - ax1 = axisA; - ax2 = axisB; - } else { - ax1 = axisB; - ax2 = axisA; + if 
(Tn.device() == Device.cpu) { + return linalg_internal::lii.Trace_ii[Tn.dtype()](Tn, axisA, axisB); } - - // 1) get redundant rank: - vector shape(Tn.shape().begin(), Tn.shape().end()); - vector accu; - shape.erase(shape.begin() + ax2); - shape.erase(shape.begin() + ax1); - // 2) get output element size. - cytnx_uint64 Nelem = 1; - for (int i = 0; i < shape.size(); i++) Nelem *= shape[i]; - - Tensor out = Tensor({Nelem}, Tn.dtype(), Tn.device()); - out.storage().set_zeros(); - - if (shape.size() == 0) { - // 2d - if (Tn.device() == Device.cpu) - linalg_internal::lii.Trace_ii[Tn.dtype()](true, out, Tn, Tn.shape()[ax1], 0, {}, {}, {}, - 0, - 0); // only the first 4 args will be used. - else { #ifdef UNI_GPU - checkCudaErrors(cudaSetDevice(Tn.device())); - linalg_internal::lii.cuTrace_ii[Tn.dtype()](true, out, Tn, Tn.shape()[ax1], 0, {}, {}, {}, - 0, - 0); // only the first 4 args will be used. + checkCudaErrors(cudaSetDevice(Tn.device())); + return linalg_internal::lii.cuTrace_ii[Tn.dtype()](Tn, axisA, axisB); #else - cytnx_error_msg(true, "[Trace] fatal error,%s", - "try to call the gpu section without CUDA support.\n"); - return out; + cytnx_error_msg(true, "[Trace] fatal error,%s", + "try to call the gpu section without CUDA support.\n"); + return Tensor(); #endif - } - } else { - // nd - vector remain_rank_id; - vector accu(shape.size()); - accu.back() = 1; - for (int i = shape.size() - 1; i > 0; i--) accu[i - 1] = accu[i] * shape[i]; - - for (cytnx_uint64 i = 0; i < Tn.shape().size(); i++) { - if (i != ax1 && i != ax2) remain_rank_id.push_back(i); - } - if (Tn.device() == Device.cpu) - linalg_internal::lii.Trace_ii[Tn.dtype()](false, out, Tn, Tn.shape()[ax1], Nelem, accu, - remain_rank_id, shape, ax1, ax2); - else { - #ifdef UNI_GPU - checkCudaErrors(cudaSetDevice(Tn.device())); - linalg_internal::lii.cuTrace_ii[Tn.dtype()](false, out, Tn, Tn.shape()[ax1], Nelem, accu, - remain_rank_id, shape, ax1, ax2); - #else - cytnx_error_msg(true, "[Trace] fatal error,%s", 
- "try to call the gpu section without CUDA support.\n"); - return out; - #endif - } - out.reshape_(shape); - } - - return out; } } // namespace linalg diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7509bfaa7..da7b356cb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -22,6 +22,7 @@ add_executable( DenseUniTensor_test.cpp Accessor_test.cpp Tensor_test.cpp + Tensor_strides_test.cpp Storage_test.cpp search_tree_test.cpp utils_test/vec_concatenate.cpp @@ -42,6 +43,7 @@ add_executable( linalg_test/Svd_truncate_test.cpp linalg_test/Gesvd_truncate_test.cpp linalg_test/Rsvd_truncate_test.cpp + linalg_test/Trace_test.cpp linalg_test/linalg_test.cpp linalg_test/stride_view_test.cpp linalg_test/sum_test.cpp @@ -80,7 +82,10 @@ add_link_options(-fsanitize=address) #target_link_libraries(test_main PUBLIC "-lgcov --coverage") include(GoogleTest) gtest_discover_tests(test_main - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + # The test binary takes longer to enumerate than the + # default 5s under MKL + ASAN on slower runners. + DISCOVERY_TIMEOUT 120) file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/testNet.net" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/tests/Tensor_strides_test.cpp b/tests/Tensor_strides_test.cpp new file mode 100644 index 000000000..3fe287d32 --- /dev/null +++ b/tests/Tensor_strides_test.cpp @@ -0,0 +1,78 @@ +#include + +#include +#include + +#include "Tensor.hpp" +#include "Type.hpp" +#include "random.hpp" + +namespace { + + // For a tensor T (possibly permuted / non-contiguous), the storage offset + // implied by strides() at multi-index idx must equal the offset that at(idx) + // actually reads -- otherwise the trace's stride-aware diagonal sum would + // disagree with Tensor::at across permutations. 
+ template + void ExpectStridesMatchAt(const cytnx::Tensor& tensor) { + const auto shape = tensor.shape(); + const auto strides = tensor.strides(); + ASSERT_EQ(strides.size(), shape.size()); + std::vector idx(shape.size(), 0); + std::function recurse = [&](std::size_t axis) { + if (axis == shape.size()) { + cytnx::cytnx_uint64 offset = 0; + for (std::size_t a = 0; a < shape.size(); ++a) offset += idx[a] * strides[a]; + EXPECT_EQ(tensor.at(idx), tensor.storage().at(offset)); + } else { + for (idx[axis] = 0; idx[axis] < shape[axis]; ++idx[axis]) recurse(axis + 1); + } + }; + recurse(0); + } + + TEST(TensorStridesTest, ContiguousIsRowMajor) { + auto tensor = + cytnx::random::random_tensor({4, 5, 3}, -1.0, 1.0, cytnx::Device.cpu, 0, cytnx::Type.Double); + ASSERT_TRUE(tensor.is_contiguous()); + const auto s = tensor.strides(); + // Row-major: strides = (5*3, 3, 1). + ASSERT_EQ(s.size(), 3u); + EXPECT_EQ(s[0], 15u); + EXPECT_EQ(s[1], 3u); + EXPECT_EQ(s[2], 1u); + ExpectStridesMatchAt(tensor); + } + + TEST(TensorStridesTest, PermutedMatchesAtForRanks2to5) { + using cytnx::cytnx_uint64; + struct Case { + std::vector shape; + std::vector perm; + }; + const std::vector cases = { + {{3, 4}, {1, 0}}, + {{3, 4, 5}, {2, 0, 1}}, + {{2, 3, 4, 2}, {3, 1, 0, 2}}, + {{2, 3, 2, 3, 2}, {4, 2, 0, 3, 1}}, + }; + for (const auto& c : cases) { + auto t = + cytnx::random::random_tensor(c.shape, -1.0, 1.0, cytnx::Device.cpu, 0, cytnx::Type.Double); + auto p = t.permute(c.perm); + EXPECT_FALSE(p.is_contiguous()) << "expected the permutation to be non-contiguous"; + ExpectStridesMatchAt(p); + } + } + + TEST(TensorStridesTest, ComplexAndIntegerDtypes) { + auto td = cytnx::random::random_tensor({3, 4, 2}, -1.0, 1.0, cytnx::Device.cpu, 0, + cytnx::Type.ComplexDouble); + ExpectStridesMatchAt(td.permute({2, 0, 1})); + + auto ti = + cytnx::random::random_tensor({3, 4, 2}, 0.0, 10.0, cytnx::Device.cpu, 0, cytnx::Type.Int32); + ExpectStridesMatchAt(ti.permute({1, 2, 0})); + } + +} // namespace diff 
--git a/tests/linalg_test/Trace_test.cpp b/tests/linalg_test/Trace_test.cpp new file mode 100644 index 000000000..1f3e6b58d --- /dev/null +++ b/tests/linalg_test/Trace_test.cpp @@ -0,0 +1,104 @@ +#include + +#include + +#include "Tensor.hpp" +#include "Type.hpp" +#include "linalg.hpp" +#include "random.hpp" +#include "test_tools.h" + +namespace { + + using cytnx::cytnx_int32; + using cytnx::cytnx_int64; + using cytnx::cytnx_uint64; + using cytnx::Device; + using cytnx::Tensor; + using cytnx::Type; + + // The strided in-place trace must agree with the trace of a fully materialized + // contiguous clone of the same tensor (which is the layout the old code always + // assumed). Pairing both via the same public API isolates the layout choice. + static Tensor ContiguousReferenceTrace(const Tensor& t, cytnx_uint64 a, cytnx_uint64 b) { + return cytnx::linalg::Trace(t.contiguous(), a, b); + } + + TEST(LinalgTraceTest, PermutedRank3MatchesContiguous) { + auto t = cytnx::random::random_tensor({4, 3, 4}, -1.0, 1.0, Device.cpu, 0, Type.Double); + auto p = t.permute({2, 1, 0}); // shape {4, 3, 4} but non-contiguous + ASSERT_FALSE(p.is_contiguous()); + auto strided = cytnx::linalg::Trace(p, 0, 2); + auto reference = ContiguousReferenceTrace(p, 0, 2); + EXPECT_TRUE(cytnx::TestTools::AreNearlyEqTensor(strided, reference, 1e-12)); + } + + TEST(LinalgTraceTest, PermutedTracesAcrossRanksAndDtypes) { + struct Case { + std::vector shape; + std::vector perm; + cytnx_uint64 ax1, ax2; + }; + // Each case picks two equal-sized logical axes after a non-trivial permute. 
+ const std::vector cases = { + // rank 3 + {{4, 3, 4}, {1, 0, 2}, 1, 2}, + // rank 4, traced axes non-adjacent in storage (permuted shape = {4, 4, 5, 3}) + {{4, 3, 4, 5}, {2, 0, 3, 1}, 0, 1}, + // rank 5 + {{3, 4, 3, 2, 3}, {4, 1, 0, 3, 2}, 0, 2}, + }; + for (unsigned int dtype : {Type.Double, Type.ComplexDouble, Type.Int32}) { + for (const auto& c : cases) { + auto t = cytnx::random::random_tensor(c.shape, -2.0, 2.0, Device.cpu, 0, dtype); + auto p = t.permute(c.perm); + EXPECT_FALSE(p.is_contiguous()); + auto strided = cytnx::linalg::Trace(p, c.ax1, c.ax2); + auto reference = ContiguousReferenceTrace(p, c.ax1, c.ax2); + EXPECT_TRUE(cytnx::TestTools::AreNearlyEqTensor(strided, reference, 1e-10)) + << "dtype=" << dtype << " rank=" << c.shape.size(); + } + } + } + + TEST(LinalgTraceTest, Rank2Path) { + // Exercises _trace_2d / cuTrace_2d_kernel (the 2d branch is only taken when + // tracing leaves no remaining axes, i.e. the input is rank 2). + auto t = cytnx::random::random_tensor({6, 6}, -1.0, 1.0, Device.cpu, 0, Type.Double); + auto p = t.permute({1, 0}); + ASSERT_FALSE(p.is_contiguous()); + auto out_p = cytnx::linalg::Trace(p, 0, 1); + auto out_c = cytnx::linalg::Trace(t, 0, 1); // tr(A) == tr(A^T) + ASSERT_EQ(out_p.shape().size(), 1u); + ASSERT_EQ(out_p.shape()[0], 1u); + EXPECT_TRUE(cytnx::TestTools::AreNearlyEqTensor(out_p, out_c, 1e-12)); + } + + TEST(LinalgTraceTest, OutputRankIsInputMinusTwo) { + // tr(rank-N) -> rank-(N-2); tr(rank-2) -> a 1-element rank-1 tensor. 
+ auto r4 = cytnx::random::random_tensor({2, 3, 2, 4}, -1.0, 1.0, Device.cpu, 0, Type.Double); + auto out4 = cytnx::linalg::Trace(r4, 0, 2); + EXPECT_EQ(out4.shape().size(), 2u); + EXPECT_EQ(out4.shape()[0], 3u); + EXPECT_EQ(out4.shape()[1], 4u); + + auto r3 = cytnx::random::random_tensor({3, 4, 3}, -1.0, 1.0, Device.cpu, 0, Type.Double); + auto out3 = cytnx::linalg::Trace(r3, 0, 2); + EXPECT_EQ(out3.shape().size(), 1u); + EXPECT_EQ(out3.shape()[0], 4u); + + auto r2 = cytnx::random::random_tensor({5, 5}, -1.0, 1.0, Device.cpu, 0, Type.Double); + auto out2 = cytnx::linalg::Trace(r2, 0, 1); + EXPECT_EQ(out2.shape().size(), 1u); + EXPECT_EQ(out2.shape()[0], 1u); + } + + TEST(LinalgTraceTest, SwappedAxisOrderMatches) { + // Trace(T, a, b) == Trace(T, b, a) (the function normalizes the order). + auto t = cytnx::random::random_tensor({3, 4, 3, 2}, -1.0, 1.0, Device.cpu, 0, Type.Double); + auto ab = cytnx::linalg::Trace(t, 0, 2); + auto ba = cytnx::linalg::Trace(t, 2, 0); + EXPECT_TRUE(cytnx::TestTools::AreNearlyEqTensor(ab, ba, 1e-12)); + } + +} // namespace