Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions sycl/include/sycl/marray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,19 @@ template <typename Type, std::size_t NumElements> class marray {
if constexpr (use_ext_vector_type) {
ext_vector_t LhsVec = sycl::bit_cast<ext_vector_t>(Lhs.MData);
ext_vector_t ResVec = ~LhsVec;
sycl::detail::memcpy_no_adl(Ret.MData, &ResVec, sizeof(ResVec));
if constexpr (std::is_same_v<DataT, bool>) {
for (size_t I = 0; I < NumElements; ++I)
Ret[I] = ResVec[I] & 1;
Comment on lines +460 to +462

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 Suggestion: The ResVec[I] & 1 expression relies on a subtle property — that ~(0 or 1) always flips the LSB — and it also wastes the ResVec = ~LhsVec computation on line 459 (the compiler will likely optimize it away, but it's dead work). Consider !LhsVec[I] instead, which mirrors the non-vector path's !Lhs[I] pattern and is immediately readable without needing to reason about bitwise complement:

Suggested change
if constexpr (std::is_same_v<DataT, bool>) {
for (size_t I = 0; I < NumElements; ++I)
Ret[I] = ResVec[I] & 1;
if constexpr (std::is_same_v<DataT, bool>) {
for (size_t I = 0; I < NumElements; ++I)
Ret[I] = !LhsVec[I];

} else {
storeVecResult(Ret.MData, ResVec);
}
} else {
for (size_t I = 0; I < NumElements; ++I) {
Ret[I] = ~Lhs[I];
if constexpr (std::is_same_v<DataT, bool>) {
for (size_t I = 0; I < NumElements; ++I)
Ret[I] = !Lhs[I];
} else {
for (size_t I = 0; I < NumElements; ++I)
Ret[I] = ~Lhs[I];
}
}
return Ret;
Expand Down
99 changes: 99 additions & 0 deletions sycl/test-e2e/Basic/built-ins/marray_bitwise_not_bool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Test bitwise NOT operator (~) on marray<bool>

#include <sycl/detail/core.hpp>
#include <sycl/marray.hpp>

#include <cassert>
#include <iostream>

using namespace sycl;

template <size_t N>
bool test_bitwise_not_bool(queue &q, const marray<bool, N> &input,
const marray<bool, N> &expected) {
bool result[N];
{
buffer<bool> out_buf(result, N);
q.submit([&](handler &h) {
accessor out_acc(out_buf, h, write_only);
h.single_task([=]() {
marray<bool, N> res = ~input;
for (size_t i = 0; i < N; ++i) {
out_acc[i] = res[i];
}
});
}).wait();
}

// Verify results
for (size_t i = 0; i < N; ++i) {
if (result[i] != expected[i]) {
std::cout << "FAILED at index " << i << ": input=" << input[i]
<< ", expected=" << expected[i] << ", got=" << result[i]
<< std::endl;
return false;
}
}
return true;
}

int main() {
queue q;

std::cout << "Testing bitwise NOT (~) on marray<bool>\n";

// Test case 1: Size 2
{
marray<bool, 2> input{true, false};
marray<bool, 2> expected{false, true};
assert(test_bitwise_not_bool(q, input, expected) &&
"Test failed for size 2");
}

// Test case 2: Size 4
{
marray<bool, 4> input{true, false, true, false};
marray<bool, 4> expected{false, true, false, true};
assert(test_bitwise_not_bool(q, input, expected) &&
"Test failed for size 4");
}

// Test case 3: Size 3 (no padding, different code path)
{
marray<bool, 3> input{false, true, false};
marray<bool, 3> expected{true, false, true};
assert(test_bitwise_not_bool(q, input, expected) &&
"Test failed for size 3");
}

// Test case 4: Size 8
{
marray<bool, 8> input{true, true, false, false, true, false, true, false};
marray<bool, 8> expected{false, false, true, true,
false, true, false, true};
assert(test_bitwise_not_bool(q, input, expected) &&
"Test failed for size 8");
}

// Test case 5: All true
{
marray<bool, 4> input{true, true, true, true};
marray<bool, 4> expected{false, false, false, false};
assert(test_bitwise_not_bool(q, input, expected) &&
"Test failed for all true");
}

// Test case 6: All false
{
marray<bool, 4> input{false, false, false, false};
marray<bool, 4> expected{true, true, true, true};
assert(test_bitwise_not_bool(q, input, expected) &&
"Test failed for all false");
}

std::cout << "All tests passed!\n";
return 0;
}
Loading