From 50b8448cded5850b1df44f463e81ef616871709a Mon Sep 17 00:00:00 2001 From: Philip Adams Date: Thu, 19 May 2022 17:51:11 -0700 Subject: [PATCH 1/4] Try using dispatcher instead of macro for our enum->type template mapping --- AnnService/inc/Core/Common.h | 64 +++++++++++++++++-- AnnService/inc/Core/Common/BKTree.h | 30 ++------- AnnService/inc/Core/Common/KDTree.h | 33 +++------- .../inc/Core/Common/NeighborhoodGraph.h | 17 ++--- .../src/Aggregator/AggregatorService.cpp | 40 +++++------- AnnService/src/Core/VectorSet.cpp | 17 ++--- Test/src/SSDServingTest.cpp | 17 ++--- 7 files changed, 107 insertions(+), 111 deletions(-) diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h index 1bad007dd..d2b649aab 100644 --- a/AnnService/inc/Core/Common.h +++ b/AnnService/inc/Core/Common.h @@ -14,6 +14,7 @@ #include #include "inc/Helper/Logging.h" #include "inc/Helper/DiskIO.h" +#include #ifndef _MSC_VER #include @@ -155,14 +156,69 @@ static_assert(static_cast(DistCalcMethod::Undefined) != 0, "Empty enum class VectorValueType : std::uint8_t { -#define DefineVectorValueType(Name, Type) Name, -#include "DefinitionList.h" -#undef DefineVectorValueType - + Int8, + UInt8, + Int16, + Float, Undefined }; static_assert(static_cast(VectorValueType::Undefined) != 0, "Empty VectorValueType!"); +using VectorValueTypeTuple = std::tuple; + +template +std::function call_with_default(F&& f) +{ + return [f]() {f(T{}); }; +} + +template +void VectorValueTypeDispatch(VectorValueType vectorType, F&& f, std::index_sequence) +{ + std::function fs[] = { + call_with_default>(f)... + }; + fs[static_cast(vectorType)](); +} + +template +void VectorValueTypeDispatch(VectorValueType vectorType, F f) +{ + constexpr auto VectorCount = std::tuple_size::value; + VectorValueTypeDispatch(vectorType, f, std::make_index_sequence{}); +} + + +/*template struct VectorValueTypeMap_t {}; +template <> struct VectorValueTypeMap_t +{ + using type = std::int8_t; +}; +template <> struct VectorValueTypeMap_t { + using type = std::uint8_t; +}; +template <> struct VectorValueTypeMap_t { + using type = std::int16_t; +}; +template <> struct VectorValueTypeMap_t { + using type = std::float_t; +}; + +template +using VectorValueTypeMap = VectorValueTypeMap_t::type; + +template +static void VectorValueTypeDispatch(VectorValueType t) +{ + constexpr for (int i = 0; i < VectorValueType::Undefined; i++) + { + if ((VectorValueType)i == t) + { + Functor.template operator() < VectorValueTypeMap > (); + } + } +}*/ + enum class IndexAlgoType : std::uint8_t { diff --git a/AnnService/inc/Core/Common/BKTree.h b/AnnService/inc/Core/Common/BKTree.h index b976b25f6..9fa7b6566 100644 --- a/AnnService/inc/Core/Common/BKTree.h +++ b/AnnService/inc/Core/Common/BKTree.h @@ -421,18 +421,7 @@ namespace SPTAG float CountStd; if (args.m_pQuantizer) { - switch (args.m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -CountStd = TryClustering(data, indices, first, last, args, samples, lambdaFactor, true); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(args.m_pQuantizer->GetReconstructType(), [&](auto t) { CountStd = TryClustering(data, indices, first, last, args, samples, lambdaFactor, true); }); } else { @@ -469,18 +458,11 @@ break; if (args.m_pQuantizer) { - switch (args.m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -TryClustering(data, indices, first, last, args, samples, lambdaFactor, debug, abort); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(args.m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + TryClustering(data, indices, first, last, args, samples, lambdaFactor, debug, abort); + }); } else { diff --git a/AnnService/inc/Core/Common/KDTree.h b/AnnService/inc/Core/Common/KDTree.h index 8d16c2f68..59f45e05f 100644 --- a/AnnService/inc/Core/Common/KDTree.h +++ b/AnnService/inc/Core/Common/KDTree.h @@ -63,18 +63,11 @@ namespace SPTAG { if (m_pQuantizer) { - switch (m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -BuildTreesCore(data, numOfThreads, indices, abort); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + BuildTreesCore(data, numOfThreads, indices, abort); + }); } else { @@ -236,17 +229,11 @@ break; { if (m_pQuantizer) { - switch (m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -return KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, distBound); - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + return VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, distBound); + }); } else { diff --git a/AnnService/inc/Core/Common/NeighborhoodGraph.h b/AnnService/inc/Core/Common/NeighborhoodGraph.h index 807bd1f9e..41ad1700e 100644 --- a/AnnService/inc/Core/Common/NeighborhoodGraph.h +++ b/AnnService/inc/Core/Common/NeighborhoodGraph.h @@ -129,18 +129,11 @@ namespace SPTAG { if (index->m_pQuantizer) { - switch (index->m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -PartitionByTptreeCore(index, indices, first, last, leaves); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + return VectorValueTypeDispatch(index->m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + PartitionByTptreeCore(index, indices, first, last, leaves); + }); } else { diff --git a/AnnService/src/Aggregator/AggregatorService.cpp b/AnnService/src/Aggregator/AggregatorService.cpp index 96a3ce726..2744739b0 100644 --- a/AnnService/src/Aggregator/AggregatorService.cpp +++ b/AnnService/src/Aggregator/AggregatorService.cpp @@ -222,29 +222,23 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID size_t vectorSize; SizeType vectorDimension = 0; std::vector servers; - switch (context->GetSettings()->m_valueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - if (!queryParser.GetVectorElements().empty()) { \ - Service::ConvertVectorFromString(queryParser.GetVectorElements(), vector, vectorDimension); \ - } else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) { \ - vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); \ - Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize); \ - vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType)); \ - } \ - for (int i = 0; i < context->GetCenters()->Count(); i++) { \ - servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(), \ - (Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod))); \ - } \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: - break; - } + VectorValueTypeDispatch(context->GetSettings()->m_valueType, [](auto t) + { + using Type = decltype(t); + if (!queryParser.GetVectorElements().empty()) { + Service::ConvertVectorFromString(queryParser.GetVectorElements(), vector, vectorDimension); + } + else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) { + vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); + Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize); + vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType)); + } + for (int i = 0; i < context->GetCenters()->Count(); i++) { + servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(), + (Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod))); + } + }); + std::sort(servers.begin(), servers.end(), [](const BasicResult& a, const BasicResult& b) { return a.Dist < b.Dist; }); for (int i = 0; i < context->GetSettings()->m_topK; i++) { auto& server = context->GetRemoteServers().at(servers[i].VID); diff --git a/AnnService/src/Core/VectorSet.cpp b/AnnService/src/Core/VectorSet.cpp index 9dbd8e960..fb5a2b4ee 100644 --- a/AnnService/src/Core/VectorSet.cpp +++ b/AnnService/src/Core/VectorSet.cpp @@ -101,16 +101,9 @@ SizeType BasicVectorSet::PerVectorDataSize() const void BasicVectorSet::Normalize(int p_threads) { - switch (m_valueType) - { -#define DefineVectorValueType(Name, Type) \ -case SPTAG::VectorValueType::Name: \ -SPTAG::COMMON::Utils::BatchNormalize(reinterpret_cast(m_data.Data()), m_vectorCount, m_dimension, SPTAG::COMMON::Utils::GetBase(), p_threads); \ -break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - break; - } + VectorValueTypeDispatch(m_valueType, [&](auto t) + { + using Type = decltype(t); + SPTAG::COMMON::Utils::BatchNormalize(reinterpret_cast(m_data.Data()), m_vectorCount, m_dimension, SPTAG::COMMON::Utils::GetBase(), p_threads); + }); } \ No newline at end of file diff --git a/Test/src/SSDServingTest.cpp b/Test/src/SSDServingTest.cpp index cba53866c..0e9d0ab26 100644 --- a/Test/src/SSDServingTest.cpp +++ b/Test/src/SSDServingTest.cpp @@ -73,19 +73,10 @@ void GenerateVectors(std::string fileName, SPTAG::SizeType rows, SPTAG::Dimensio } -void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) { - switch (vecType) - { -#define DefineVectorValueType(Name, Type) \ -case SPTAG::VectorValueType::Name: \ -GenerateVectors(vectorsName, rows, dims, vecFileType); \ -break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - break; - } +void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) +{ + // Renan: "F# is faster than C++ and we should use it for this algorithm instead" + VectorValueTypeDispatch(vecType, [&](auto t) {GenerateVectors(vectorsName, rows, dims, vecFileType); }); } std::string CreateBaseConfig(SPTAG::VectorValueType p_valueType, SPTAG::DistCalcMethod p_distCalcMethod, From 1f4c0b6c33dbaaf7410abffacb5e9225c39ad24e Mon Sep 17 00:00:00 2001 From: Philip Adams Date: Thu, 19 May 2022 19:56:56 -0700 Subject: [PATCH 2/4] throw on Undefined type, fox aggregator capture mode --- AnnService/inc/Core/Common.h | 1 + AnnService/src/Aggregator/AggregatorService.cpp | 2 +- Test/src/SSDServingTest.cpp | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h index d2b649aab..16c12b7f0 100644 --- a/AnnService/inc/Core/Common.h +++ b/AnnService/inc/Core/Common.h @@ -184,6 +184,7 @@ void VectorValueTypeDispatch(VectorValueType vectorType, F&& f, std::index_seque template void VectorValueTypeDispatch(VectorValueType vectorType, F f) { + if (vectorType == VectorValueType::Undefined) throw std::exception("VectorValueTypeDispatch on Undefined type"); constexpr auto VectorCount = std::tuple_size::value; VectorValueTypeDispatch(vectorType, f, std::make_index_sequence{}); } diff --git a/AnnService/src/Aggregator/AggregatorService.cpp b/AnnService/src/Aggregator/AggregatorService.cpp index 2744739b0..faf549cd7 100644 --- a/AnnService/src/Aggregator/AggregatorService.cpp +++ b/AnnService/src/Aggregator/AggregatorService.cpp @@ -222,7 +222,7 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID size_t vectorSize; SizeType vectorDimension = 0; std::vector servers; - VectorValueTypeDispatch(context->GetSettings()->m_valueType, [](auto t) + VectorValueTypeDispatch(context->GetSettings()->m_valueType, [&](auto t) { using Type = decltype(t); if (!queryParser.GetVectorElements().empty()) { diff --git a/Test/src/SSDServingTest.cpp b/Test/src/SSDServingTest.cpp index 0e9d0ab26..1dfcc71f8 100644 --- a/Test/src/SSDServingTest.cpp +++ b/Test/src/SSDServingTest.cpp @@ -75,7 +75,6 @@ void GenerateVectors(std::string fileName, SPTAG::SizeType rows, SPTAG::Dimensio void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) { - // Renan: "F# is faster than C++ and we should use it for this algorithm instead" VectorValueTypeDispatch(vecType, [&](auto t) {GenerateVectors(vectorsName, rows, dims, vecFileType); }); } From cb324f53a8ac35393c1f5ece38ef36c089de3688 Mon Sep 17 00:00:00 2001 From: Philip Adams Date: Thu, 19 May 2022 21:19:22 -0700 Subject: [PATCH 3/4] Add remove_last interface to allow generating the variadic list using the macro interface --- AnnService/inc/Core/Common.h | 96 +++++++++---------- AnnService/src/Core/Common/IQuantizer.cpp | 48 +--------- AnnService/src/Core/VectorIndex.cpp | 56 +++-------- .../src/Helper/VectorSetReaders/TxtReader.cpp | 19 +--- AnnService/src/IndexSearcher/main.cpp | 15 +-- AnnService/src/Quantizer/main.cpp | 15 +-- AnnService/src/SSDServing/main.cpp | 36 +++---- .../src/Server/SearchExecutionContext.cpp | 19 ++-- 8 files changed, 95 insertions(+), 209 deletions(-) diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h index 16c12b7f0..c098ade77 100644 --- a/AnnService/inc/Core/Common.h +++ b/AnnService/inc/Core/Common.h @@ -153,19 +153,46 @@ enum class DistCalcMethod : std::uint8_t }; static_assert(static_cast(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!"); - enum class VectorValueType : std::uint8_t { - Int8, - UInt8, - Int16, - Float, +#define DefineVectorValueType(Name, Type) Name, +#include "DefinitionList.h" +#undef DefineVectorValueType Undefined }; static_assert(static_cast(VectorValueType::Undefined) != 0, "Empty VectorValueType!"); -using VectorValueTypeTuple = std::tuple; +// remove_last is by Vladimir Reshetnikov, https://stackoverflow.com/a/51805324 +template +struct remove_last; + +template<> +struct remove_last>; // Define as you wish or leave undefined + +template +struct remove_last> +{ +private: + using Tuple = std::tuple; + + template + static std::tuple...> + extract(std::index_sequence); + +public: + using type = decltype(extract(std::make_index_sequence())); +}; + +template +using remove_last_t = typename remove_last::type; +using VectorValueTypeTuple = remove_last_t>; + +// Dispatcher is based on https://stackoverflow.com/a/34046180 template std::function call_with_default(F&& f) { @@ -179,47 +206,22 @@ void VectorValueTypeDispatch(VectorValueType vectorType, F&& f, std::index_seque call_with_default>(f)... }; fs[static_cast(vectorType)](); + } template void VectorValueTypeDispatch(VectorValueType vectorType, F f) { - if (vectorType == VectorValueType::Undefined) throw std::exception("VectorValueTypeDispatch on Undefined type"); constexpr auto VectorCount = std::tuple_size::value; - VectorValueTypeDispatch(vectorType, f, std::make_index_sequence{}); -} - - -/*template struct VectorValueTypeMap_t {}; -template <> struct VectorValueTypeMap_t -{ - using type = std::int8_t; -}; -template <> struct VectorValueTypeMap_t { - using type = std::uint8_t; -}; -template <> struct VectorValueTypeMap_t { - using type = std::int16_t; -}; -template <> struct VectorValueTypeMap_t { - using type = std::float_t; -}; - -template -using VectorValueTypeMap = VectorValueTypeMap_t::type; - -template -static void VectorValueTypeDispatch(VectorValueType t) -{ - constexpr for (int i = 0; i < VectorValueType::Undefined; i++) + if ((int)vectorType < VectorCount) { - if ((VectorValueType)i == t) - { - Functor.template operator() < VectorValueTypeMap > (); - } + VectorValueTypeDispatch(vectorType, f, std::make_index_sequence{}); } -}*/ - + else + { + throw std::exception(); + } +} enum class IndexAlgoType : std::uint8_t { @@ -271,20 +273,10 @@ constexpr VectorValueType GetEnumValueType() \ inline std::size_t GetValueTypeSize(VectorValueType p_valueType) { - switch (p_valueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return sizeof(Type); \ - -#include "DefinitionList.h" -#undef DefineVectorValueType - - default: - break; - } + std::size_t out = 0; + VectorValueTypeDispatch(p_valueType, [&](auto t) { out = sizeof(decltype(t)); }); - return 0; + return out; } enum class QuantizerType : std::uint8_t diff --git a/AnnService/src/Core/Common/IQuantizer.cpp b/AnnService/src/Core/Common/IQuantizer.cpp index 5f2629c26..9c2a0bb58 100644 --- a/AnnService/src/Core/Common/IQuantizer.cpp +++ b/AnnService/src/Core/Common/IQuantizer.cpp @@ -20,31 +20,11 @@ namespace SPTAG case QuantizerType::Undefined: break; case QuantizerType::PQQuantizer: - switch (reconstructType) { - #define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new PQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } - + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer()); }); if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset(); return ret; case QuantizerType::OPQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new OPQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: break; - } + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new OPQQuantizer()); }); if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset(); return ret; } @@ -68,31 +48,11 @@ namespace SPTAG case QuantizerType::Undefined: return ret; case QuantizerType::PQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new PQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: break; - } - + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer()); }); if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset(); return ret; case QuantizerType::OPQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new OPQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: break; - } - + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer()); }); if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset(); return ret; } diff --git a/AnnService/src/Core/VectorIndex.cpp b/AnnService/src/Core/VectorIndex.cpp index 60a194090..7a4337b2e 100644 --- a/AnnService/src/Core/VectorIndex.cpp +++ b/AnnService/src/Core/VectorIndex.cpp @@ -540,45 +540,18 @@ VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) { return nullptr; } - + std::shared_ptr out; if (p_algo == IndexAlgoType::BKT) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new BKT::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(p_valuetype, [&](auto t) {out.reset(new BKT::Index); }); + return out; } else if (p_algo == IndexAlgoType::KDT) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new KDT::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(p_valuetype, [&](auto t) {out.reset(new KDT::Index); }); + return out; } else if (p_algo == IndexAlgoType::SPANN) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new SPANN::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(p_valuetype, [&](auto t) {out.reset(new SPANN::Index); }); + return out; } return nullptr; } @@ -930,16 +903,11 @@ void VectorIndex::ApproximateRNG(std::shared_ptr& fullVectors, std::u { reconstructed_vector = ALIGN_ALLOC(m_pQuantizer->ReconstructSize()); m_pQuantizer->ReconstructVector((const uint8_t*)fullVectors->GetVector(fullID), reconstructed_vector); - switch (m_pQuantizer->GetReconstructType()) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - (*((COMMON::QueryResultSet*)&resultSet)).SetTarget(reinterpret_cast(reconstructed_vector), m_pQuantizer); \ - break; -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - LOG(Helper::LogLevel::LL_Error, "Unable to get quantizer reconstruct type %s", Helper::Convert::ConvertToString(m_pQuantizer->GetReconstructType())); - } + VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + (*((COMMON::QueryResultSet*) & resultSet)).SetTarget(reinterpret_cast(reconstructed_vector), m_pQuantizer); + }); } else { diff --git a/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp b/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp index 7b97f84f8..812620b4b 100644 --- a/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp +++ b/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp @@ -222,20 +222,11 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, } bool parseSuccess = false; - switch (m_options->m_inputValueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: - parseSuccess = false; - break; - } + VectorValueTypeDispatch(m_options->m_inputValueType, [&](auto t) + { + using Type = decltype(t); + parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); + }); if (!parseSuccess) { diff --git a/AnnService/src/IndexSearcher/main.cpp b/AnnService/src/IndexSearcher/main.cpp index 9bc376450..cc12f0d69 100644 --- a/AnnService/src/IndexSearcher/main.cpp +++ b/AnnService/src/IndexSearcher/main.cpp @@ -330,17 +330,10 @@ int main(int argc, char** argv) vecIndex->UpdateIndex(); - switch (options->m_inputValueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - Process(options, *(vecIndex.get())); \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType + VectorValueTypeDispatch(options->m_inputValueType, [&](auto t) + { + Process(options, *(vecIndex.get())); + }); - default: break; - } return 0; } diff --git a/AnnService/src/Quantizer/main.cpp b/AnnService/src/Quantizer/main.cpp index e4283c039..f95e15b28 100644 --- a/AnnService/src/Quantizer/main.cpp +++ b/AnnService/src/Quantizer/main.cpp @@ -54,16 +54,11 @@ int main(int argc, char* argv[]) { LOG(Helper::LogLevel::LL_Info, "Quantizer Does not exist. Training a new one.\n"); - switch (options->m_inputValueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - quantizer.reset(new COMMON::PQQuantizer(options->m_quantizedDim, 256, (DimensionType)(options->m_dimension/options->m_quantizedDim), false, TrainPQQuantizer(options, fullvectors, quantized_vectors))); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - } + VectorValueTypeDispatch(options->m_inputValueType, [&](auto t) + { + using Type = decltype(t); + quantizer.reset(new COMMON::PQQuantizer(options->m_quantizedDim, 256, (DimensionType)(options->m_dimension / options->m_quantizedDim), false, TrainPQQuantizer(options, fullvectors, quantized_vectors))); + }); auto ptr = SPTAG::f_createIO(); if (ptr != nullptr && ptr->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::out)) diff --git a/AnnService/src/SSDServing/main.cpp b/AnnService/src/SSDServing/main.cpp index 7aa372b58..e0510ad66 100644 --- a/AnnService/src/SSDServing/main.cpp +++ b/AnnService/src/SSDServing/main.cpp @@ -107,13 +107,8 @@ namespace SPTAG { SPANN::Options* opts = nullptr; -#define DefineVectorValueType(Name, Type) \ - if (index->GetVectorValueType() == VectorValueType::Name) { \ - opts = ((SPANN::Index*)index.get())->GetOptions(); \ - } \ + VectorValueTypeDispatch(opts->m_valueType, [&](auto t) { opts = ((SPANN::Index*)index.get())->GetOptions(); }); -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType if (opts == nullptr) { LOG(Helper::LogLevel::LL_Error, "Cannot get options.\n"); @@ -149,26 +144,25 @@ namespace SPTAG { omp_set_num_threads(opts->m_iSSDNumberOfThreads); -#define DefineVectorValueType(Name, Type) \ - if (opts->m_valueType == VectorValueType::Name) { \ - COMMON::TruthSet::GenerateTruth(querySet, vectorSet, opts->m_truthPath, \ - distCalcMethod, opts->m_resultNum, opts->m_truthType, index->m_pQuantizer); \ - } \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType + VectorValueTypeDispatch(opts->m_valueType, [&](auto t) + { + COMMON::TruthSet::GenerateTruth(querySet, + vectorSet, + opts->m_truthPath, + distCalcMethod, + opts->m_resultNum, + opts->m_truthType, + index->m_pQuantizer); + }); LOG(Helper::LogLevel::LL_Info, "End generating truth.\n"); } if (searchSSD) { -#define DefineVectorValueType(Name, Type) \ - if (opts->m_valueType == VectorValueType::Name) { \ - SSDIndex::Search((SPANN::Index*)(index.get())); \ - } \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType + VectorValueTypeDispatch(opts->m_valueType, [&](auto t) + { + SSDIndex::Search((SPANN::Index*)(index.get())); + }); } return 0; } diff --git a/AnnService/src/Server/SearchExecutionContext.cpp b/AnnService/src/Server/SearchExecutionContext.cpp index 45d83ec43..88a7c5a82 100644 --- a/AnnService/src/Server/SearchExecutionContext.cpp +++ b/AnnService/src/Server/SearchExecutionContext.cpp @@ -82,19 +82,12 @@ SearchExecutionContext::ExtractVector(VectorValueType p_targetType) { if (!m_queryParser.GetVectorElements().empty()) { - switch (p_targetType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return ConvertVectorFromString(m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: - break; - } + ErrorCode err; + VectorValueTypeDispatch(p_targetType, [&](auto t) + { + err = ConvertVectorFromString(m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); + }); + return err; } else if (m_queryParser.GetVectorBase64() != nullptr && m_queryParser.GetVectorBase64Length() != 0) From 1e8797846bad538c2b609c0da363e7156fed2f97 Mon Sep 17 00:00:00 2001 From: Philip Adams Date: Thu, 19 May 2022 22:35:50 -0700 Subject: [PATCH 4/4] Fix RUN_FROM_MAP and KDTReconstructTest tests --- AnnService/src/SSDServing/main.cpp | 2 +- Test/src/ReconstructIndexSimilarityTest.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/AnnService/src/SSDServing/main.cpp b/AnnService/src/SSDServing/main.cpp index e0510ad66..8f48519d3 100644 --- a/AnnService/src/SSDServing/main.cpp +++ b/AnnService/src/SSDServing/main.cpp @@ -107,7 +107,7 @@ namespace SPTAG { SPANN::Options* opts = nullptr; - VectorValueTypeDispatch(opts->m_valueType, [&](auto t) { opts = ((SPANN::Index*)index.get())->GetOptions(); }); + VectorValueTypeDispatch(index->GetVectorValueType(), [&](auto t) { opts = ((SPANN::Index*)index.get())->GetOptions(); }); if (opts == nullptr) { diff --git a/Test/src/ReconstructIndexSimilarityTest.cpp b/Test/src/ReconstructIndexSimilarityTest.cpp index 2da00493b..59e11fc09 100644 --- a/Test/src/ReconstructIndexSimilarityTest.cpp +++ b/Test/src/ReconstructIndexSimilarityTest.cpp @@ -165,7 +165,7 @@ void GenerateReconstructData(std::shared_ptr& real_vecset, std::share if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) { BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error"); } - quantizer->LoadIQuantizer(ptr); + quantizer = COMMON::IQuantizer::LoadIQuantizer(ptr); BOOST_ASSERT(quantizer); std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); @@ -247,7 +247,7 @@ void GenerateReconstructData(std::shared_ptr& real_vecset, std::share if (!ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) { BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error"); } - quantizer->LoadIQuantizer(ptr); + quantizer = COMMON::IQuantizer::LoadIQuantizer(ptr); BOOST_ASSERT(quantizer); rec_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(R) * n * m), GetEnumValueType(), m, n));