Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion quest/src/core/bitwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
#include <intrin.h>
#endif

#if defined(__BMI2__) && (defined(__x86_64__) || defined(_M_X64)) && !defined(__NVCC__) && !defined(__HIP__)
#include <immintrin.h>
#define QUEST_HAVE_BMI2_INTRINSICS 1
#endif

#include "quest/include/types.h"

#include "quest/src/core/inliner.hpp"
Expand Down Expand Up @@ -166,6 +171,24 @@ INLINE int getBitMaskParity(qindex mask) {
INLINE qindex insertBits(qindex number, const int* bitIndices, int numIndices, int bitValue) {

// bitIndices must be strictly increasing

#ifdef QUEST_HAVE_BMI2_INTRINSICS
// Smaller fixed-size cases are usually unrolled by callers already.
if (numIndices > 5) {
unsigned long long insertionMask = 0;

for (int i=0; i<numIndices; i++)
insertionMask |= 1ULL << bitIndices[i];

unsigned long long inserted = _pdep_u64(static_cast<unsigned long long>(number), ~insertionMask);

if (bitValue)
inserted |= insertionMask;

return static_cast<qindex>(inserted);
}
#endif

for (int i=0; i<numIndices; i++)
number = insertBit(number, bitIndices[i], bitValue);

Expand All @@ -188,6 +211,27 @@ INLINE qindex setBits(qindex number, const int* bitIndices, int numIndices, qind
INLINE qindex getValueOfBits(qindex number, const int* bitIndices, int numIndices) {

// bits are arbitrarily ordered, which affects value

#ifdef QUEST_HAVE_BMI2_INTRINSICS
// Smaller fixed-size cases are usually unrolled by callers already.
if (numIndices > 5) {
unsigned long long extractionMask = 0;
bool indicesAreIncreasing = true;
int prevInd = -1;

for (int i=0; i<numIndices; i++) {
int bitInd = bitIndices[i];
extractionMask |= 1ULL << bitInd;
indicesAreIncreasing &= bitInd > prevInd;
prevInd = bitInd;
}

// PEXT returns bits in ascending mask order, matching this API only for sorted indices.
if (indicesAreIncreasing)
return static_cast<qindex>(_pext_u64(static_cast<unsigned long long>(number), extractionMask));
}
#endif

qindex value = 0;

for (int i=0; i<numIndices; i++)
Expand Down Expand Up @@ -379,4 +423,4 @@ INLINE void setToBitsOfInteger(int* bits, qindex number, int numBits) {



#endif // BITWISE_HPP
#endif // BITWISE_HPP
3 changes: 2 additions & 1 deletion tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

target_sources(tests
PUBLIC
bitwise.cpp
calculations.cpp
channels.cpp
debug.cpp
Expand All @@ -16,4 +17,4 @@ target_sources(tests
qureg.cpp
trotterisation.cpp
types.cpp
)
)
122 changes: 122 additions & 0 deletions tests/unit/bitwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/** @file
* Unit tests of internal bitwise subroutines.
*
* @defgroup unitbitwise Bitwise
* @ingroup unittests
*/

#include <catch2/catch_test_macros.hpp>

#include "quest/src/core/bitwise.hpp"
#include "tests/utils/macros.hpp"

#include <vector>

using std::vector;


/*
* UTILITIES
*/

#define TEST_CATEGORY \
LABEL_UNIT_TAG "[bitwise]"


static qindex getRefInsertedBits(qindex number, const vector<int>& bitIndices, int bitValue) {

qindex out = 0;
int srcInd = 0;
int nextIns = 0;

for (int dstInd=0; dstInd<63; dstInd++) {
if (nextIns < static_cast<int>(bitIndices.size()) && dstInd == bitIndices[nextIns]) {
if (bitValue)
out |= QINDEX_ONE << dstInd;
nextIns++;
} else {
if ((number >> srcInd) & QINDEX_ONE)
out |= QINDEX_ONE << dstInd;
srcInd++;
}
}

return out;
}


static qindex getRefValueOfBits(qindex number, const vector<int>& bitIndices) {

qindex out = 0;

for (int i=0; i<static_cast<int>(bitIndices.size()); i++)
if ((number >> bitIndices[i]) & QINDEX_ONE)
out |= QINDEX_ONE << i;

return out;
}


/**
* TESTS
*
* @ingroup unitbitwise
* @{
*/


TEST_CASE( "insertBits", TEST_CATEGORY ) {

SECTION( LABEL_CORRECTNESS ) {
vector<qindex> numbers = {0, 1, 2, 5, 21, 0x12345, 0x6DB6DB};
vector<vector<int>> bitIndices = {
{},
{0},
{1},
{0, 1},
{1, 3, 5},
{0, 2, 6, 9},
{4, 8, 12, 20}
};

for (auto number: numbers)
for (auto& inds: bitIndices)
for (int bitValue: {0, 1})
REQUIRE( insertBits(number, inds.data(), static_cast<int>(inds.size()), bitValue) == getRefInsertedBits(number, inds, bitValue) );
}

SECTION( LABEL_VALIDATION ) {

// no validation!
SUCCEED( );
}
}


TEST_CASE( "getValueOfBits", TEST_CATEGORY ) {

SECTION( LABEL_CORRECTNESS ) {
vector<qindex> numbers = {0, 1, 2, 5, 0x12345, 0xAAAAAAAA, 0x55555555};
vector<vector<int>> bitIndices = {
{},
{0},
{1},
{0, 1, 2, 3},
{3, 1, 4, 0},
{20, 0, 16, 8, 4}
};

for (auto number: numbers)
for (auto& inds: bitIndices)
REQUIRE( getValueOfBits(number, inds.data(), static_cast<int>(inds.size())) == getRefValueOfBits(number, inds) );
}

SECTION( LABEL_VALIDATION ) {

// no validation!
SUCCEED( );
}
}


/** @} (end defgroup) */