diff --git a/bindings/cpp/CMakeLists.txt b/bindings/cpp/CMakeLists.txt index b10cfda12..7869c4822 100644 --- a/bindings/cpp/CMakeLists.txt +++ b/bindings/cpp/CMakeLists.txt @@ -123,14 +123,14 @@ if (SVS_RUNTIME_ENABLE_LVQ_LEANVEC) else() # Links to LTO-enabled static library, requires GCC/G++ 11.2 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "11.2" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.3") - set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.2.0/svs-shared-library-0.2.0-lto-ivf.tar.gz" + set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.3.0/svs-shared-library-0.3.0-lto-ivf.tar.gz" CACHE STRING "URL to download SVS shared library") else() message(WARNING "Pre-built LVQ/LeanVec SVS library requires GCC/G++ v.11.2 to apply LTO optimizations." "Current compiler: ${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}" ) - set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.2.0/svs-shared-library-0.2.0-ivf.tar.gz" + set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.3.0/svs-shared-library-0.3.0-ivf.tar.gz" CACHE STRING "URL to download SVS shared library") endif() include(FetchContent) diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 4b16cf4bc..b27958703 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -125,8 +125,20 @@ class DynamicVamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = impl_->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + // Use adaptive batch sizing: start with at least k candidates, + // then adjust based on observed filter hit rate. 
+ auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); + const auto max_batch_size = batch_size; do { - iterator.next(k); + // Estimate how many candidates we need to find remaining + // results given the observed hit rate so far. + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + // Cap to avoid oversized batches in the iterator. + batch_size = std::min(batch_size, max_batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index e0d7c68af..6caa1a325 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -431,6 +431,19 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) { } } // namespace storage +// Predict how many more items need to be processed to reach the goal, +// based on the observed hit rate so far. +// If no hits yet, returns `hint` unchanged. +// The caller should cap the result to a max batch size if needed. 
+inline size_t +predict_further_processing(size_t processed, size_t hits, size_t goal, size_t hint) { + if (hits == 0 || hits >= goal) { + return hint; + } + float batch_size = static_cast<float>(goal - hits) * processed / hits; + return std::max(static_cast<size_t>(batch_size), size_t{1}); +} + inline svs::threads::ThreadPoolHandle default_threadpool() { return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()) ); diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 4cf58d7e0..2fd1f1452 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -131,8 +131,20 @@ class VamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = get_impl()->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + // Use adaptive batch sizing: start with at least k candidates, + // then adjust based on observed filter hit rate. + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); + const auto max_batch_size = batch_size; do { - iterator.next(k); + // Estimate how many candidates we need to find remaining + // results given the observed hit rate so far. + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + // Cap to avoid oversized batches in the iterator. 
+ batch_size = std::min(batch_size, max_batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3c..92b819894 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -501,6 +501,54 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { svs::runtime::v0::DynamicVamanaIndex::destroy(index); } +CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector<size_t> labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector<float> distances(nq * k); + std::vector<size_t> result_labels(nq * k); + + status = + index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter); + CATCH_REQUIRE(status.ok()); + + // All returned labels must fall inside the filter range + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + } + + 
svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") { const auto& test_data = get_test_data(); // Build index diff --git a/docker/x86_64/manylinux2014/oneAPI.repo b/docker/x86_64/manylinux2014/oneAPI.repo index ba35b673e..ccf387ba8 100644 --- a/docker/x86_64/manylinux2014/oneAPI.repo +++ b/docker/x86_64/manylinux2014/oneAPI.repo @@ -5,3 +5,4 @@ enabled=1 gpgcheck=1 repo_gpgcheck=1 gpgkey=https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB +sslverify=0 diff --git a/examples/cpp/shared/CMakeLists.txt b/examples/cpp/shared/CMakeLists.txt index 4c6b93db9..cec6b4308 100644 --- a/examples/cpp/shared/CMakeLists.txt +++ b/examples/cpp/shared/CMakeLists.txt @@ -24,7 +24,7 @@ find_package(svs QUIET) if(NOT svs_FOUND) # If sourcing from pip/conda, the following steps are not necessary, simplifying workflow # If not found, download tarball from GitHub release and follow steps to fetch and find - set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.2.0/svs-shared-library-0.2.0.tar.gz" CACHE STRINGS "URL to download SVS shared library tarball if not found in system") + set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.3.0/svs-shared-library-0.3.0.tar.gz" CACHE STRINGS "URL to download SVS shared library tarball if not found in system") message(STATUS "SVS not found in system, downloading from: ${SVS_URL}")