From 059477bd9bddf20d74b7ac9e83071701b01c3dbd Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Tue, 10 Feb 2026 14:25:13 -0600 Subject: [PATCH 1/2] CASSANDRA-20086: Use score ordered iterators for ANN search (cherry picked from commit f2e34117d6a313be14156d92a3485f29424144c1) --- .../config/CassandraRelevantProperties.java | 18 +- .../db/SinglePartitionReadCommand.java | 54 ++- .../cassandra/db/lifecycle/Tracker.java | 6 +- .../cassandra/index/sai/QueryContext.java | 8 +- .../index/sai/StorageAttachedIndexGroup.java | 5 + .../index/sai/VectorQueryContext.java | 196 -------- .../cassandra/index/sai/disk/EmptyIndex.java | 121 +++++ .../index/sai/disk/PrimaryKeyMap.java | 7 + .../index/sai/disk/SSTableIndex.java | 9 +- .../sai/disk/format/IndexDescriptor.java | 5 +- .../sai/disk/v1/SkinnyPrimaryKeyMap.java | 17 +- .../index/sai/disk/v1/V1SSTableIndex.java | 27 +- .../index/sai/disk/v1/WidePrimaryKeyMap.java | 9 +- .../disk/v1/postings/VectorPostingList.java | 76 --- .../disk/v1/segment/IndexSegmentSearcher.java | 21 +- .../index/sai/disk/v1/segment/Segment.java | 26 +- .../sai/disk/v1/segment/SegmentBuilder.java | 2 +- .../sai/disk/v1/segment/SegmentOrdering.java | 22 +- .../segment/VectorIndexSegmentSearcher.java | 292 ++++++----- .../vector/AutoResumingNodeScoreIterator.java | 161 +++++++ .../v1/vector/BruteForceRowIdIterator.java | 123 +++++ .../index/sai/disk/v1/vector/DiskAnn.java | 123 ++--- .../v1/vector/NeighborQueueRowIdIterator.java | 45 ++ .../NodeScoreToRowIdWithScoreIterator.java | 81 ++++ .../index/sai/disk/v1/vector/OnHeapGraph.java | 81 ++-- .../index/sai/disk/v1/vector/OptimizeFor.java | 2 +- .../disk/v1/vector/PrimaryKeyWithScore.java | 91 ++++ .../RowIdToPrimaryKeyWithScoreIterator.java | 69 +++ .../sai/disk/v1/vector/RowIdWithScore.java | 48 ++ .../v1/vector/SegmentRowIdOrdinalPairs.java | 131 +++++ .../sai/iterators/KeyRangeListIterator.java | 67 --- .../iterators/KeyRangeOrderingIterator.java | 94 ---- .../sai/iterators/PriorityQueueIterator.java | 47 ++ .../index/sai/memory/MemtableIndex.java | 27 +- .../sai/memory/MemtableIndexManager.java | 22 +- .../index/sai/memory/MemtableOrdering.java | 31 +- .../index/sai/memory/TrieMemoryIndex.java | 15 + .../index/sai/memory/VectorMemoryIndex.java | 190 +++++--- .../cassandra/index/sai/plan/Operation.java | 28 +- .../index/sai/plan/QueryController.java | 225 +++++---- ...terializesTooManyPrimaryKeysException.java | 29 ++ .../index/sai/plan/QueryViewBuilder.java | 19 +- .../plan/StorageAttachedIndexQueryPlan.java | 2 +- .../plan/StorageAttachedIndexSearcher.java | 425 ++++++++++++---- .../index/sai/plan/VectorTopKProcessor.java | 49 +- .../index/sai/utils/CellWithSource.java | 245 ++++++++++ .../MergePrimaryKeyWithScoreIterator.java | 70 +++ .../index/sai/utils/RowWithSource.java | 390 +++++++++++++++ .../index/sai/view/IndexViewManager.java | 21 +- .../apache/cassandra/index/sai/view/View.java | 4 +- .../sai/virtual/SSTableIndexesSystemView.java | 6 + .../cassandra/index/sasi/SASIIndex.java | 2 +- .../apache/cassandra/io/sstable/SSTable.java | 5 + .../MemtableSwitchedNotification.java | 8 +- .../org/apache/cassandra/cql3/CQLTester.java | 36 ++ .../cassandra/db/lifecycle/TrackerTest.java | 6 +- .../apache/cassandra/index/sai/SAITester.java | 10 +- .../sai/cql/StorageAttachedIndexDDLTest.java | 8 +- .../index/sai/cql/VectorSiftSmallTest.java | 143 +++++- .../cassandra/index/sai/cql/VectorTester.java | 39 +- .../index/sai/cql/VectorTypeTest.java | 45 ++ .../index/sai/cql/VectorUpdateDeleteTest.java | 456 ++++++++++++++++-- 
.../disk/v1/InvertedIndexSearcherTest.java | 8 + .../bbtree/BlockBalancedTreeIndexBuilder.java | 9 +- .../index/sai/functional/FlushingTest.java | 4 +- .../sai/functional/GroupComponentsTest.java | 3 +- .../sai/memory/VectorMemoryIndexTest.java | 56 ++- 67 files changed, 3594 insertions(+), 1126 deletions(-) delete mode 100644 src/java/org/apache/cassandra/index/sai/VectorQueryContext.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/EmptyIndex.java delete mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/AutoResumingNodeScoreIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/BruteForceRowIdIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/NeighborQueueRowIdIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/NodeScoreToRowIdWithScoreIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/PrimaryKeyWithScore.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdToPrimaryKeyWithScoreIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdWithScore.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/vector/SegmentRowIdOrdinalPairs.java delete mode 100644 src/java/org/apache/cassandra/index/sai/iterators/KeyRangeListIterator.java delete mode 100644 src/java/org/apache/cassandra/index/sai/iterators/KeyRangeOrderingIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/iterators/PriorityQueueIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/plan/QueryMaterializesTooManyPrimaryKeysException.java create mode 100644 src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java create mode 100644 src/java/org/apache/cassandra/index/sai/utils/MergePrimaryKeyWithScoreIterator.java create mode 100644 src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java index cf514b7dfce1..3eb3aadb7a75 100644 --- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java +++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java @@ -505,15 +505,21 @@ public enum CassandraRelevantProperties /** Whether to allow the user to specify custom options to the hnsw index */ SAI_VECTOR_ALLOW_CUSTOM_PARAMETERS("cassandra.sai.vector.allow_custom_parameters", "false"), - /** Controls the maximum top-k limit for vector search */ - SAI_VECTOR_SEARCH_MAX_TOP_K("cassandra.sai.vector_search.max_top_k", "1000"), - /** - * Controls the maximum number of PrimaryKeys that will be read into memory at one time when ordering/limiting - * the results of an ANN query constrained by non-ANN predicates. + * The maximum number of primary keys that a WHERE clause may materialize before the query planner switches + * from a search-then-sort execution strategy to an order-by-then-filter strategy. Increasing this limit allows + * more primary keys to be buffered in memory, enabling either (a) brute-force sorting or (b) graph traversal + * with a restrictive filter that admits only nodes whose primary keys matched the WHERE clause. 
+ * + * Note also that the SAI_INTERSECTION_CLAUSE_LIMIT is applied to the WHERE clause before using a search to + * build a potential result set for search-then-sort query execution. */ - SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE("cassandra.sai.vector_search.order_chunk_size", "100000"), + SAI_VECTOR_SEARCH_MAX_MATERIALIZE_KEYS("cassandra.sai.vector_search.max_materialized_keys", "16000"), + + /** Controls the maximum top-k limit for vector search */ + SAI_VECTOR_SEARCH_MAX_TOP_K("cassandra.sai.vector_search.max_top_k", "1000"), + SCHEMA_PULL_INTERVAL_MS("cassandra.schema_pull_interval_ms", "60000"), SCHEMA_UPDATE_HANDLER_FACTORY_CLASS("cassandra.schema.update_handler_factory.class"), SEARCH_CONCURRENCY_FACTOR("cassandra.search_concurrency_factor", "1"), diff --git a/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java b/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java index e6c7d1ae62d8..053a0f132957 100644 --- a/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java +++ b/src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java @@ -56,6 +56,7 @@ import org.apache.cassandra.db.partitions.PartitionIterators; import org.apache.cassandra.db.partitions.SingletonUnfilteredPartitionIterator; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; +import org.apache.cassandra.db.rows.BaseRowIterator; import org.apache.cassandra.db.rows.Cell; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Rows; @@ -726,10 +727,26 @@ public UnfilteredRowIterator queryMemtableAndDisk(ColumnFamilyStore cfs, ReadExe assert executionController != null && executionController.validForReadOn(cfs); Tracing.trace("Executing single-partition query on {}", cfs.name); - return queryMemtableAndDiskInternal(cfs, executionController); + Tracing.trace("Acquiring sstable references"); + ColumnFamilyStore.ViewFragment view = cfs.select(View.select(SSTableSet.LIVE, partitionKey())); + return queryMemtableAndDiskInternal(cfs, view, null, executionController); + } + + public UnfilteredRowIterator queryMemtableAndDisk(ColumnFamilyStore cfs, + ColumnFamilyStore.ViewFragment view, + Function>> rowTransformer, + ReadExecutionController executionController) + { + assert executionController != null && executionController.validForReadOn(cfs); + Tracing.trace("Executing single-partition query on {}", cfs.name); + + return queryMemtableAndDiskInternal(cfs, view, rowTransformer, executionController); } - private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs, ReadExecutionController controller) + private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs, + ColumnFamilyStore.ViewFragment view, + Function>> rowTransformer, + ReadExecutionController controller) { /* * We have 2 main strategies: @@ -753,11 +770,9 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs && !queriesMulticellType() && !controller.isTrackingRepairedStatus()) { - return queryMemtableAndSSTablesInTimestampOrder(cfs, (ClusteringIndexNamesFilter)clusteringIndexFilter(), controller); + return queryMemtableAndSSTablesInTimestampOrder(cfs, view, rowTransformer, (ClusteringIndexNamesFilter)clusteringIndexFilter(), controller); } - Tracing.trace("Acquiring sstable references"); - ColumnFamilyStore.ViewFragment view = cfs.select(View.select(SSTableSet.LIVE, partitionKey())); view.sstables.sort(SSTableReader.maxTimestampDescending); ClusteringIndexFilter filter = clusteringIndexFilter(); long minTimestamp = 
Long.MAX_VALUE; @@ -776,6 +791,9 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs if (memtable.getMinTimestamp() != Memtable.NO_MIN_TIMESTAMP) minTimestamp = Math.min(minTimestamp, memtable.getMinTimestamp()); + if (rowTransformer != null) + iter = Transformation.apply(iter, rowTransformer.apply(memtable)); + // Memtable data is always considered unrepaired controller.updateMinOldestUnrepairedTombstone(memtable.getMinLocalDeletionTime()); inputCollector.addMemtableIterator(RTBoundValidator.validate(iter, RTBoundValidator.Stage.MEMTABLE, false)); @@ -835,6 +853,9 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs UnfilteredRowIterator iter = intersects ? makeRowIteratorWithLowerBound(cfs, sstable, metricsCollector) : makeRowIteratorWithSkippedNonStaticContent(cfs, sstable, metricsCollector); + if (rowTransformer != null) + iter = Transformation.apply(iter, rowTransformer.apply(sstable.getId())); + inputCollector.addSSTableIterator(sstable, iter); mostRecentPartitionTombstone = Math.max(mostRecentPartitionTombstone, iter.partitionLevelDeletion().markedForDeleteAt()); @@ -857,6 +878,10 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs { if (!sstable.isRepaired()) controller.updateMinOldestUnrepairedTombstone(sstable.getMinLocalDeletionTime()); + + if (rowTransformer != null) + iter = Transformation.apply(iter, rowTransformer.apply(sstable.getId())); + inputCollector.addSSTableIterator(sstable, iter); includedDueToTombstones++; mostRecentPartitionTombstone = Math.max(mostRecentPartitionTombstone, @@ -996,11 +1021,8 @@ private boolean queriesMulticellType() * no collection or counters are included). * This method assumes the filter is a {@code ClusteringIndexNamesFilter}. */ - private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFamilyStore cfs, ClusteringIndexNamesFilter filter, ReadExecutionController controller) + private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFamilyStore cfs, ColumnFamilyStore.ViewFragment view, Function>> rowTransformer, ClusteringIndexNamesFilter filter, ReadExecutionController controller) { - Tracing.trace("Acquiring sstable references"); - ColumnFamilyStore.ViewFragment view = cfs.select(View.select(SSTableSet.LIVE, partitionKey())); - ImmutableBTreePartition result = null; SSTableReadMetricsCollector metricsCollector = new SSTableReadMetricsCollector(); @@ -1012,7 +1034,9 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam if (iter == null) continue; - result = add(RTBoundValidator.validate(iter, RTBoundValidator.Stage.MEMTABLE, false), + UnfilteredRowIterator wrapped = rowTransformer != null ? Transformation.apply(iter, rowTransformer.apply(memtable)) + : iter; + result = add(RTBoundValidator.validate(wrapped, RTBoundValidator.Stage.MEMTABLE, false), result, filter, false, @@ -1067,7 +1091,10 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam } else { - result = add(RTBoundValidator.validate(iter, RTBoundValidator.Stage.SSTABLE, false), + UnfilteredRowIterator wrapped = rowTransformer != null ? 
Transformation.apply(iter, rowTransformer.apply(sstable.getId())) + : iter; + + result = add(RTBoundValidator.validate(wrapped, RTBoundValidator.Stage.SSTABLE, false), result, filter, sstable.isRepaired(), @@ -1082,8 +1109,9 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam { if (iter.isEmpty()) continue; - - result = add(RTBoundValidator.validate(iter, RTBoundValidator.Stage.SSTABLE, false), + UnfilteredRowIterator wrapped = rowTransformer != null ? Transformation.apply(iter, rowTransformer.apply(sstable.getId())) + : iter; + result = add(RTBoundValidator.validate(wrapped, RTBoundValidator.Stage.SSTABLE, false), result, filter, sstable.isRepaired(), diff --git a/src/java/org/apache/cassandra/db/lifecycle/Tracker.java b/src/java/org/apache/cassandra/db/lifecycle/Tracker.java index a443b38fff1c..b207d86165eb 100644 --- a/src/java/org/apache/cassandra/db/lifecycle/Tracker.java +++ b/src/java/org/apache/cassandra/db/lifecycle/Tracker.java @@ -419,7 +419,7 @@ public Memtable switchMemtable(boolean truncating, Memtable newMemtable) if (truncating) notifyRenewed(newMemtable); else - notifySwitched(result.left.getCurrentMemtable()); + notifySwitched(result.left.getCurrentMemtable(), result.right.getCurrentMemtable()); return result.left.getCurrentMemtable(); } @@ -577,9 +577,9 @@ public void notifyRenewed(Memtable renewed) notify(new MemtableRenewedNotification(renewed)); } - public void notifySwitched(Memtable previous) + public void notifySwitched(Memtable previous, Memtable next) { - notify(new MemtableSwitchedNotification(previous)); + notify(new MemtableSwitchedNotification(previous, next)); } public void notifyDiscarded(Memtable discarded) diff --git a/src/java/org/apache/cassandra/index/sai/QueryContext.java b/src/java/org/apache/cassandra/index/sai/QueryContext.java index 41ff703bb8c7..d5f03267a9f8 100644 --- a/src/java/org/apache/cassandra/index/sai/QueryContext.java +++ b/src/java/org/apache/cassandra/index/sai/QueryContext.java @@ -71,8 +71,6 @@ public class QueryContext * */ public boolean hasUnrepairedMatches = false; - private VectorQueryContext vectorContext; - public QueryContext(ReadCommand readCommand, long executionQuotaMs) { this.readCommand = readCommand; @@ -94,10 +92,8 @@ public void checkpoint() } } - public VectorQueryContext vectorContext() + public int limit() { - if (vectorContext == null) - vectorContext = new VectorQueryContext(readCommand); - return vectorContext; + return readCommand.limits().count(); } } diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndexGroup.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndexGroup.java index 046af69618cc..748f32c2cc7e 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndexGroup.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndexGroup.java @@ -63,6 +63,7 @@ import org.apache.cassandra.notifications.INotificationConsumer; import org.apache.cassandra.notifications.MemtableDiscardedNotification; import org.apache.cassandra.notifications.MemtableRenewedNotification; +import org.apache.cassandra.notifications.MemtableSwitchedNotification; import org.apache.cassandra.notifications.SSTableAddedNotification; import org.apache.cassandra.notifications.SSTableListChangedNotification; import org.apache.cassandra.schema.TableMetadata; @@ -277,6 +278,10 @@ else if (notification instanceof MemtableRenewedNotification) { indexes.forEach(index -> index.memtableIndexManager().renewMemtable(((MemtableRenewedNotification) 
notification).renewed)); } + else if (notification instanceof MemtableSwitchedNotification) + { + indexes.forEach(index -> index.memtableIndexManager().maybeInitializeMemtableIndex(((MemtableSwitchedNotification) notification).next)); + } else if (notification instanceof MemtableDiscardedNotification) { indexes.forEach(index -> index.memtableIndexManager().discardMemtable(((MemtableDiscardedNotification) notification).memtable)); diff --git a/src/java/org/apache/cassandra/index/sai/VectorQueryContext.java b/src/java/org/apache/cassandra/index/sai/VectorQueryContext.java deleted file mode 100644 index 499af1601b7f..000000000000 --- a/src/java/org/apache/cassandra/index/sai/VectorQueryContext.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.index.sai; - -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.NavigableSet; -import java.util.Set; -import java.util.TreeSet; - -import org.apache.cassandra.db.ReadCommand; -import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; -import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; -import org.apache.cassandra.index.sai.disk.v1.vector.DiskAnn; -import org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph; -import org.apache.cassandra.index.sai.utils.PrimaryKey; - -import io.github.jbellis.jvector.util.Bits; - - -/** - * This represents the state of a vector query. It is repsonsible for maintaining a list of any {@link PrimaryKey}s - * that have been updated or deleted during a search of the indexes. - *

- * The number of {@link #shadowedPrimaryKeys} is compared before and after a search is performed. If it changes, it - * means that a {@link PrimaryKey} was found to have been changed. In this case the whole search is repeated until the - * counts match. - *

- * When this process has completed, a {@link Bits} array is generated. This is used by the vector graph search to - * identify which nodes in the graph to include in the results. - */ -public class VectorQueryContext -{ - private final int limit; - // Holds primary keys that are shadowed by expired TTL or row tombstone or range tombstone. - // They are populated by the StorageAttachedIndexSearcher during filtering. They are used to generate - // a bitset for the graph search to indicate graph nodes to ignore. - private TreeSet shadowedPrimaryKeys; - - public VectorQueryContext(ReadCommand readCommand) - { - this.limit = readCommand.limits().count(); - } - - public int limit() - { - return limit; - } - - public void recordShadowedPrimaryKeys(Set keys) - { - if (shadowedPrimaryKeys == null) - shadowedPrimaryKeys = new TreeSet<>(); - shadowedPrimaryKeys.addAll(keys); - } - - // Returns true if the row ID will be included or false if the row ID will be shadowed - public boolean shouldInclude(long sstableRowId, PrimaryKeyMap primaryKeyMap) - { - return shadowedPrimaryKeys == null || !shadowedPrimaryKeys.contains(primaryKeyMap.primaryKeyFromRowId(sstableRowId)); - } - - public boolean shouldInclude(PrimaryKey pk) - { - return shadowedPrimaryKeys == null || !shadowedPrimaryKeys.contains(pk); - } - - public boolean containsShadowedPrimaryKey(PrimaryKey primaryKey) - { - return shadowedPrimaryKeys != null && shadowedPrimaryKeys.contains(primaryKey); - } - - /** - * @return shadowed primary keys, in ascending order - */ - public NavigableSet getShadowedPrimaryKeys() - { - if (shadowedPrimaryKeys == null) - return Collections.emptyNavigableSet(); - return shadowedPrimaryKeys; - } - - public Bits bitsetForShadowedPrimaryKeys(OnHeapGraph graph) - { - if (shadowedPrimaryKeys == null) - return null; - - return new IgnoredKeysBits(graph, shadowedPrimaryKeys); - } - - public Bits bitsetForShadowedPrimaryKeys(SegmentMetadata metadata, PrimaryKeyMap primaryKeyMap, DiskAnn graph) throws IOException - { - Set ignoredOrdinals = null; - try (var ordinalsView = graph.getOrdinalsView()) - { - for (PrimaryKey primaryKey : getShadowedPrimaryKeys()) - { - // not in current segment - if (primaryKey.compareTo(metadata.minKey) < 0 || primaryKey.compareTo(metadata.maxKey) > 0) - continue; - - long sstableRowId = primaryKeyMap.rowIdFromPrimaryKey(primaryKey); - if (sstableRowId == Long.MAX_VALUE) // not found - continue; - - int segmentRowId = Math.toIntExact(sstableRowId - metadata.rowIdOffset); - // not in segment yet - if (segmentRowId < 0) - continue; - // end of segment - if (segmentRowId > metadata.maxSSTableRowId) - break; - - int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId); - if (ordinal >= 0) - { - if (ignoredOrdinals == null) - ignoredOrdinals = new HashSet<>(); - ignoredOrdinals.add(ordinal); - } - } - } - - if (ignoredOrdinals == null) - return null; - - return new IgnoringBits(ignoredOrdinals, metadata); - } - - private static class IgnoringBits implements Bits - { - private final Set ignoredOrdinals; - private final int length; - - public IgnoringBits(Set ignoredOrdinals, SegmentMetadata metadata) - { - this.ignoredOrdinals = ignoredOrdinals; - this.length = 1 + Math.toIntExact(metadata.maxSSTableRowId - metadata.rowIdOffset); - } - - @Override - public boolean get(int index) - { - return !ignoredOrdinals.contains(index); - } - - @Override - public int length() - { - return length; - } - } - - private static class IgnoredKeysBits implements Bits - { - private final OnHeapGraph graph; - private 
final NavigableSet ignored; - - public IgnoredKeysBits(OnHeapGraph graph, NavigableSet ignored) - { - this.graph = graph; - this.ignored = ignored; - } - - @Override - public boolean get(int ordinal) - { - Collection keys = graph.keysFromOrdinal(ordinal); - return keys.stream().anyMatch(k -> !ignored.contains(k)); - } - - @Override - public int length() - { - return graph.size(); - } - } -} diff --git a/src/java/org/apache/cassandra/index/sai/disk/EmptyIndex.java b/src/java/org/apache/cassandra/index/sai/disk/EmptyIndex.java new file mode 100644 index 000000000000..4f4046b3290a --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/EmptyIndex.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.virtual.SimpleDataSet; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.SSTableContext; +import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; +import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.plan.Expression; +import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.utils.CloseableIterator; + +/** + * A placeholder index for when there is no on-disk index. + * + * Currently only used by vector indexes because ANN queries require a complete view of the table's sstables, even if + * the associated sstable does not have any data indexed for the column. 
+ */ +public class EmptyIndex extends SSTableIndex +{ + public EmptyIndex(SSTableContext sstableContext, StorageAttachedIndex index) + { + super(sstableContext, index); + } + + @Override + public long indexFileCacheSize() + { + return 0; + } + + @Override + public long getRowCount() + { + return 0; + } + + @Override + public long minSSTableRowId() + { + return -1; + } + + @Override + public long maxSSTableRowId() + { + return -1; + } + + @Override + public ByteBuffer minTerm() + { + return null; + } + + @Override + public ByteBuffer maxTerm() + { + return null; + } + + @Override + public AbstractBounds bounds() + { + return null; + } + + @Override + public List search(Expression expression, AbstractBounds keyRange, QueryContext context) throws IOException + { + return List.of(); + } + + @Override + public List> orderBy(Expression orderer, AbstractBounds keyRange, QueryContext context) throws IOException + { + return List.of(); + } + + @Override + public List> orderResultsBy(QueryContext context, List results, Expression orderer) throws IOException + { + return List.of(); + } + + @Override + public void populateSegmentView(SimpleDataSet dataSet) + { + + } + + @Override + protected void internalRelease() + { + + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java index 88e72efab7fa..b2ab64fdc25f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java @@ -26,6 +26,7 @@ import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.io.sstable.SSTableId; /** * A bidirectional map of {@link PrimaryKey} to row ID. Implementations of this interface @@ -55,6 +56,12 @@ default void close() } } + /** + * Returns the {@link SSTableId} associated with this {@link PrimaryKeyMap} + * @return an {@link SSTableId} + */ + SSTableId getSSTableId(); + /** * Returns a {@link PrimaryKey} for a row ID * diff --git a/src/java/org/apache/cassandra/index/sai/disk/SSTableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/SSTableIndex.java index f887a818ca04..087d9288c8e5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/SSTableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/SSTableIndex.java @@ -38,13 +38,15 @@ import org.apache.cassandra.index.sai.SSTableContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.disk.v1.segment.SegmentOrdering; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.IndexIdentifier; import org.apache.cassandra.index.sai.utils.IndexTermType; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.io.sstable.SSTableIdFactory; import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.utils.CloseableIterator; /** * A reference-counted container of a {@link SSTableReader} for each column index that: @@ -54,7 +56,7 @@ *

  • Exposes the index metadata for the column index
  • * */ -public abstract class SSTableIndex implements SegmentOrdering, Comparable +public abstract class SSTableIndex implements Comparable { private static final Logger logger = LoggerFactory.getLogger(SSTableIndex.class); @@ -144,6 +146,9 @@ public abstract List search(Expression expression, AbstractBounds keyRange, QueryContext context) throws IOException; + public abstract List> orderBy(Expression orderer, AbstractBounds keyRange, QueryContext context) throws IOException; + public abstract List> orderResultsBy(QueryContext context, List results, Expression orderer) throws IOException; + /** * Populates a virtual table using the index metadata owned by the index */ diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/IndexDescriptor.java b/src/java/org/apache/cassandra/index/sai/disk/format/IndexDescriptor.java index 4d94084cc33a..1d9ee774ded7 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/format/IndexDescriptor.java +++ b/src/java/org/apache/cassandra/index/sai/disk/format/IndexDescriptor.java @@ -38,6 +38,7 @@ import org.apache.cassandra.index.sai.IndexValidation; import org.apache.cassandra.index.sai.SSTableContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.disk.EmptyIndex; import org.apache.cassandra.index.sai.disk.PerColumnIndexWriter; import org.apache.cassandra.index.sai.disk.PerSSTableIndexWriter; import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; @@ -127,7 +128,9 @@ public PrimaryKeyMap.Factory newPrimaryKeyMapFactory(SSTableReader sstable) public SSTableIndex newSSTableIndex(SSTableContext sstableContext, StorageAttachedIndex index) { - return version.onDiskFormat().newSSTableIndex(sstableContext, index); + return isIndexEmpty(index.termType(), index.identifier()) + ? 
new EmptyIndex(sstableContext, index) + : version.onDiskFormat().newSSTableIndex(sstableContext, index); } public PerSSTableIndexWriter newPerSSTableIndexWriter() throws IOException diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SkinnyPrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SkinnyPrimaryKeyMap.java index ab02c1c3cd4e..844bfe0c7363 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SkinnyPrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SkinnyPrimaryKeyMap.java @@ -35,6 +35,7 @@ import org.apache.cassandra.index.sai.disk.v1.keystore.KeyLookup; import org.apache.cassandra.index.sai.disk.v1.keystore.KeyLookupMeta; import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.Throwables; @@ -67,6 +68,7 @@ public static class Factory implements PrimaryKeyMap.Factory protected final LongArray.Factory rowToPartitionReaderFactory; protected final KeyLookup partitionKeyReader; protected final PrimaryKey.Factory primaryKeyFactory; + protected final SSTableId sstableId; private final FileHandle rowToTokenFile; private final FileHandle rowToPartitionFile; @@ -90,6 +92,7 @@ public Factory(IndexDescriptor indexDescriptor) KeyLookupMeta partitionKeysMeta = new KeyLookupMeta(metadataSource.get(indexDescriptor.componentName(IndexComponent.PARTITION_KEY_BLOCKS))); this.partitionKeyReader = new KeyLookup(partitionKeyBlocksFile, partitionKeyBlockOffsetsFile, partitionKeysMeta, partitionKeyBlockOffsetsMeta); this.primaryKeyFactory = indexDescriptor.primaryKeyFactory; + this.sstableId = indexDescriptor.sstableDescriptor.id; } catch (Throwable t) { @@ -106,7 +109,8 @@ public PrimaryKeyMap newPerSSTablePrimaryKeyMap() throws IOException return new SkinnyPrimaryKeyMap(rowIdToToken, rowIdToPartitionId, partitionKeyReader.openCursor(), - primaryKeyFactory); + primaryKeyFactory, + sstableId); } @Override @@ -120,16 +124,25 @@ public void close() protected final LongArray rowIdToPartitionIdArray; protected final KeyLookup.Cursor partitionKeyCursor; protected final PrimaryKey.Factory primaryKeyFactory; + protected final SSTableId sstableId; protected SkinnyPrimaryKeyMap(LongArray rowIdToTokenArray, LongArray rowIdToPartitionIdArray, KeyLookup.Cursor partitionKeyCursor, - PrimaryKey.Factory primaryKeyFactory) + PrimaryKey.Factory primaryKeyFactory, + SSTableId sstableId) { this.rowIdToTokenArray = rowIdToTokenArray; this.rowIdToPartitionIdArray = rowIdToPartitionIdArray; this.partitionKeyCursor = partitionKeyCursor; this.primaryKeyFactory = primaryKeyFactory; + this.sstableId = sstableId; + } + + @Override + public SSTableId getSSTableId() + { + return sstableId; } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/V1SSTableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/v1/V1SSTableIndex.java index 0945444a0067..044529aac3f4 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/V1SSTableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/V1SSTableIndex.java @@ -36,12 +36,13 @@ import org.apache.cassandra.index.sai.disk.SSTableIndex; import org.apache.cassandra.index.sai.disk.v1.segment.Segment; import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; -import 
org.apache.cassandra.index.sai.iterators.KeyRangeUnionIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.Throwables; import static org.apache.cassandra.index.sai.virtual.SegmentsSystemView.CELL_COUNT; @@ -175,14 +176,28 @@ public List search(Expression expression, return segmentIterators; } - @Override - public KeyRangeIterator limitToTopKResults(QueryContext context, List primaryKeys, Expression expression) throws IOException + public List> orderBy(Expression orderer, AbstractBounds keyRange, QueryContext context) throws IOException + { + // Return a list to allow the caller to merge the results from multiple sstables into a single iterator. + List> iterators = new ArrayList<>(segments.size()); + for (Segment segment : segments) + { + if (segment.intersects(keyRange)) + { + iterators.add(segment.orderBy(orderer, keyRange, context)); + } + } + return iterators; + } + + public List> orderResultsBy(QueryContext context, List results, Expression orderer) throws IOException { - KeyRangeUnionIterator.Builder unionIteratorBuilder = KeyRangeUnionIterator.builder(segments.size()); + // Return a list to allow the caller to merge the results from multiple sstables into a single iterator. + List> iterators = new ArrayList<>(segments.size()); for (Segment segment : segments) - unionIteratorBuilder.add(segment.limitToTopKResults(context, primaryKeys, expression)); + iterators.add(segment.orderResultsBy(context, results, orderer)); - return unionIteratorBuilder.build(); + return iterators; } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/WidePrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/v1/WidePrimaryKeyMap.java index c6e7737cf06e..8f4a3097dd6a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/WidePrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/WidePrimaryKeyMap.java @@ -35,6 +35,7 @@ import org.apache.cassandra.index.sai.disk.v1.keystore.KeyLookup; import org.apache.cassandra.index.sai.disk.v1.keystore.KeyLookupMeta; import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.io.util.FileUtils; @@ -104,7 +105,8 @@ public PrimaryKeyMap newPerSSTablePrimaryKeyMap() throws IOException partitionKeyReader.openCursor(), clusteringKeyReader.openCursor(), primaryKeyFactory, - clusteringComparator); + clusteringComparator, + sstableId); } @Override @@ -125,9 +127,10 @@ private WidePrimaryKeyMap(LongArray rowIdToTokenArray, KeyLookup.Cursor partitionKeyCursor, KeyLookup.Cursor clusteringKeyCursor, PrimaryKey.Factory primaryKeyFactory, - ClusteringComparator clusteringComparator) + ClusteringComparator clusteringComparator, + SSTableId sstableId) { - super(rowIdToTokenArray, rowIdToPartitionIdArray, partitionKeyCursor, primaryKeyFactory); + super(rowIdToTokenArray, rowIdToPartitionIdArray, partitionKeyCursor, primaryKeyFactory, sstableId); this.partitionIdToSizeArray = partitionIdToSizeArray; this.clusteringComparator = clusteringComparator; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java 
b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java deleted file mode 100644 index fefc0a3a34c5..000000000000 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.index.sai.disk.v1.postings; - -import java.io.IOException; -import java.util.PrimitiveIterator; - -import org.apache.lucene.util.LongHeap; - -import org.apache.cassandra.index.sai.postings.PostingList; - -/** - * A {@link PostingList} for ANN search results. Transforms result from similarity order to row ID order. - */ -public class VectorPostingList implements PostingList -{ - private final LongHeap segmentRowIds; - private final int size; - private final int visitedCount; - - public VectorPostingList(PrimitiveIterator.OfInt source, int limit, int visitedCount) - { - this.visitedCount = visitedCount; - segmentRowIds = new LongHeap(Math.max(limit, 1)); - int n = 0; - while (source.hasNext() && n++ < limit) - segmentRowIds.push(source.nextInt()); - this.size = n; - } - - @Override - public long nextPosting() - { - if (segmentRowIds.size() == 0) - return PostingList.END_OF_STREAM; - return segmentRowIds.pop(); - } - - @Override - public long size() - { - return size; - } - - @Override - public long advance(long targetRowID) throws IOException - { - long rowId; - do - { - rowId = nextPosting(); - } while (rowId < targetRowID); - return rowId; - } - - public int getVisitedCount() - { - return visitedCount; - } -} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/IndexSegmentSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/IndexSegmentSearcher.java index 96f389d794c6..802797a682bd 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/IndexSegmentSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/IndexSegmentSearcher.java @@ -27,10 +27,13 @@ import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles; import org.apache.cassandra.index.sai.disk.v1.postings.PostingListRangeIterator; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.postings.PeekablePostingList; import org.apache.cassandra.index.sai.postings.PostingList; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.utils.CloseableIterator; /** * Abstract reader for individual segments of an on-disk index. 
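The per-sstable orderBy methods above return one score-ordered iterator per segment so that the caller can merge them lazily into a single stream (the new MergePrimaryKeyWithScoreIterator in the diffstat plays this role). Below is a minimal sketch of that k-way merge technique using only JDK types; every name in it is an illustrative stand-in rather than the patch's API.

    import java.util.Iterator;
    import java.util.List;
    import java.util.PriorityQueue;

    // A key paired with its similarity score; stands in for the patch's PrimaryKeyWithScore.
    final class ScoredKey
    {
        final Object primaryKey;
        final float score;
        ScoredKey(Object primaryKey, float score) { this.primaryKey = primaryKey; this.score = score; }
    }

    // Lazily merges several iterators, each already in descending score order, into one
    // descending stream by keeping the head of every source in a max-heap.
    final class DescendingScoreMerge implements Iterator<ScoredKey>
    {
        private final PriorityQueue<Source> heap =
            new PriorityQueue<>((a, b) -> Float.compare(b.head.score, a.head.score));

        DescendingScoreMerge(List<Iterator<ScoredKey>> sources)
        {
            for (Iterator<ScoredKey> source : sources)
                if (source.hasNext())
                    heap.add(new Source(source));
        }

        @Override
        public boolean hasNext()
        {
            return !heap.isEmpty();
        }

        @Override
        public ScoredKey next()
        {
            Source top = heap.poll();
            ScoredKey result = top.head;
            if (top.advance())
                heap.add(top); // re-insert so the heap reflects the source's new head
            return result;
        }

        private static final class Source
        {
            final Iterator<ScoredKey> iterator;
            ScoredKey head;

            Source(Iterator<ScoredKey> iterator)
            {
                this.iterator = iterator;
                this.head = iterator.next();
            }

            boolean advance()
            {
                if (!iterator.hasNext())
                    return false;
                head = iterator.next();
                return true;
            }
        }
    }

Re-inserting a source after each poll keeps the merge at O(log n) per emitted element in the number of sources, and no source is advanced further than the consumer actually pulls.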
@@ -57,12 +60,13 @@ public abstract class IndexSegmentSearcher implements SegmentOrdering, Closeable } public static IndexSegmentSearcher open(PrimaryKeyMap.Factory primaryKeyMapFactory, + SSTableId sstableId, PerColumnIndexFiles indexFiles, SegmentMetadata segmentMetadata, StorageAttachedIndex index) throws IOException { if (index.termType().isVector()) - return new VectorIndexSegmentSearcher(primaryKeyMapFactory, indexFiles, segmentMetadata, index); + return new VectorIndexSegmentSearcher(primaryKeyMapFactory, sstableId, indexFiles, segmentMetadata, index); else if (index.termType().isLiteral()) return new LiteralIndexSegmentSearcher(primaryKeyMapFactory, indexFiles, segmentMetadata, index); else @@ -84,6 +88,21 @@ else if (index.termType().isLiteral()) */ public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext) throws IOException; + /** + * Order the rows by the given expression. + * + * @param orderer the object containing the ordering logic + * @param keyRange the key range specified in the read command, used by the ANN index + * @param context to track per sstable cache and per query metrics + * + * @return an iterator of {@link PrimaryKeyWithScore} in descending score order + */ + public CloseableIterator orderBy(Expression orderer, AbstractBounds keyRange, QueryContext context) throws IOException + { + throw new UnsupportedOperationException(); + } + + KeyRangeIterator toPrimaryKeyIterator(PostingList postingList, QueryContext queryContext) throws IOException { if (postingList == null || postingList.size() == 0) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/Segment.java index 6e08551bd108..2354e798020d 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/Segment.java @@ -30,12 +30,13 @@ import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.SSTableContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; -import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.CloseableIterator; /** * Each segment represents an on-disk index structure (balanced tree/terms/postings) flushed by memory limit or token boundaries. 
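Since the orderBy contract above guarantees descending score order, a consumer can take the top k results by simply draining the first k entries, with no additional sort or heap. A sketch under that assumption, borrowing the patch's CloseableIterator and PrimaryKeyWithScore types; the takeTopK helper itself is hypothetical.

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore;
    import org.apache.cassandra.utils.CloseableIterator;

    final class TopK
    {
        // Drains at most k entries; because the source is already score-ordered,
        // these are exactly the best k matches in the segment.
        static List<PrimaryKeyWithScore> takeTopK(CloseableIterator<PrimaryKeyWithScore> scored, int k)
        {
            List<PrimaryKeyWithScore> top = new ArrayList<>(k);
            try (CloseableIterator<PrimaryKeyWithScore> it = scored)
            {
                while (it.hasNext() && top.size() < k)
                    top.add(it.next());
            }
            return top;
        }
    }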
@@ -47,8 +48,6 @@ public class Segment implements SegmentOrdering, Closeable private final Token.KeyBound minKeyBound; private final Token.KeyBound maxKeyBound; - // per sstable - final PrimaryKeyMap.Factory primaryKeyMapFactory; // per-segment public final SegmentMetadata metadata; @@ -59,16 +58,14 @@ public Segment(StorageAttachedIndex index, SSTableContext sstableContext, PerCol this.minKeyBound = metadata.minKey.token().minKeyBound(); this.maxKeyBound = metadata.maxKey.token().maxKeyBound(); - this.primaryKeyMapFactory = sstableContext.primaryKeyMapFactory; this.metadata = metadata; - this.index = IndexSegmentSearcher.open(primaryKeyMapFactory, indexFiles, metadata, index); + this.index = IndexSegmentSearcher.open(sstableContext.primaryKeyMapFactory, sstableContext.sstable.getId(), indexFiles, metadata, index); } @VisibleForTesting public Segment(Token minKey, Token maxKey) { - this.primaryKeyMapFactory = null; this.metadata = null; this.minKeyBound = minKey.minKeyBound(); this.maxKeyBound = maxKey.maxKeyBound(); @@ -112,10 +109,23 @@ public KeyRangeIterator search(Expression expression, AbstractBounds orderBy(Expression orderer, AbstractBounds keyRange, QueryContext context) throws IOException + { + return index.orderBy(orderer, keyRange, context); + } + @Override - public KeyRangeIterator limitToTopKResults(QueryContext context, List primaryKeys, Expression expression) throws IOException + public CloseableIterator orderResultsBy(QueryContext context, List results, Expression orderer) throws IOException { - return index.limitToTopKResults(context, primaryKeys, expression); + return index.orderResultsBy(context, results, orderer); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentBuilder.java index 492cec245ed0..6793638428ae 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentBuilder.java @@ -115,7 +115,7 @@ public static class VectorSegmentBuilder extends SegmentBuilder public VectorSegmentBuilder(StorageAttachedIndex index, NamedMemoryLimiter limiter) { super(index, limiter); - graphIndex = new OnHeapGraph<>(index.termType().indexType(), index.indexWriterConfig(), false); + graphIndex = new OnHeapGraph<>(index.termType().indexType(), index.indexWriterConfig(), null); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentOrdering.java b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentOrdering.java index 616b0ea86d97..dcb6e4273707 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentOrdering.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/SegmentOrdering.java @@ -23,24 +23,28 @@ import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.utils.CloseableIterator; /** - * A {@link SegmentOrdering} orders and limits a list of {@link PrimaryKey}s. - *

+ * A {@link SegmentOrdering} orders an index and produces a stream of {@link PrimaryKeyWithScore}s. + * + * The limit can be used to lazily order the {@link PrimaryKey}s. Due to the possibility of + * shadowed or updated keys, a {@link SegmentOrdering} should be able to order the whole index + * until exhausted. + * + * When using {@link SegmentOrdering} there are several steps to - * build the list of Primary Keys to be ordered and limited: -

    + * build the list of Primary Keys to be ordered: + * * 1. Find all primary keys that match each non-ordering query predicate. * 2. Union and intersect the results of step 1 to build a single {@link KeyRangeIterator} * ordered by {@link PrimaryKey}. - * 3. Filter out any shadowed primary keys. - * 4. Fan the primary keys from step 3 out to each sstable segment to order and limit each - * list of primary keys. + * 3. Fan the primary keys from step 2 out to each sstable segment to order the list of primary keys. *

    - * SegmentOrdering handles the fourth step. + * SegmentOrdering handles the third step. *

    * Note: a segment ordering is only used when a query has both ordering and non-ordering predicates. * Where a query has only ordering predicates, the ordering is handled by @@ -51,7 +55,7 @@ public interface SegmentOrdering /** * Reorder, limit, and put back into original order the results from a single sstable */ - default KeyRangeIterator limitToTopKResults(QueryContext queryContext, List primaryKeys, Expression expression) throws IOException + default CloseableIterator orderResultsBy(QueryContext queryContext, List results, Expression orderer) throws IOException { throw new UnsupportedOperationException(); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/VectorIndexSegmentSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/VectorIndexSegmentSearcher.java index 1875ec7a8b71..7a3937e8af1a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/segment/VectorIndexSegmentSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/segment/VectorIndexSegmentSearcher.java @@ -20,40 +20,47 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.List; +import java.util.function.IntConsumer; import java.util.stream.Collectors; -import javax.annotation.Nullable; - +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import com.google.common.base.Preconditions; -import org.agrona.collections.IntArrayList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.NeighborQueue; +import io.github.jbellis.jvector.graph.NeighborSimilarity; +import io.github.jbellis.jvector.pq.CompressedVectors; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.SparseFixedBitSet; import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; -import org.apache.cassandra.index.sai.VectorQueryContext; import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles; -import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList; +import org.apache.cassandra.index.sai.disk.v1.vector.BruteForceRowIdIterator; import org.apache.cassandra.index.sai.disk.v1.vector.DiskAnn; +import org.apache.cassandra.index.sai.disk.v1.vector.NeighborQueueRowIdIterator; +import org.apache.cassandra.index.sai.disk.v1.vector.OnDiskOrdinalsMap; import org.apache.cassandra.index.sai.disk.v1.vector.OptimizeFor; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; +import org.apache.cassandra.index.sai.disk.v1.vector.RowIdToPrimaryKeyWithScoreIterator; +import org.apache.cassandra.index.sai.disk.v1.vector.RowIdWithScore; +import org.apache.cassandra.index.sai.disk.v1.vector.SegmentRowIdOrdinalPairs; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; -import org.apache.cassandra.index.sai.iterators.KeyRangeListIterator; import org.apache.cassandra.index.sai.memory.VectorMemoryIndex; import org.apache.cassandra.index.sai.plan.Expression; -import org.apache.cassandra.index.sai.postings.IntArrayPostingList; -import org.apache.cassandra.index.sai.postings.PostingList; import org.apache.cassandra.index.sai.utils.AtomicRatio; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.RangeUtil; +import 
org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.tracing.Tracing; - -import io.github.jbellis.jvector.util.Bits; -import io.github.jbellis.jvector.util.SparseFixedBitSet; +import org.apache.cassandra.utils.CloseableIterator; import static java.lang.Math.max; import static java.lang.Math.min; @@ -65,22 +72,27 @@ public class VectorIndexSegmentSearcher extends IndexSegmentSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + // If true, use brute force. If false, use graph search. If null, use the normal logic. + @VisibleForTesting + public static Boolean FORCE_BRUTE_FORCE_ANN = null; + private final DiskAnn graph; - private final int globalBruteForceRows; private final AtomicRatio actualExpectedRatio = new AtomicRatio(); private final ThreadLocal cachedBitSets; private final OptimizeFor optimizeFor; + private final ColumnMetadata column; VectorIndexSegmentSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, + SSTableId sstableId, PerColumnIndexFiles perIndexFiles, SegmentMetadata segmentMetadata, StorageAttachedIndex index) throws IOException { super(primaryKeyMapFactory, perIndexFiles, segmentMetadata, index); - graph = new DiskAnn(segmentMetadata.componentMetadatas, perIndexFiles, index.indexWriterConfig()); + graph = new DiskAnn(segmentMetadata.componentMetadatas, perIndexFiles, index.indexWriterConfig(), sstableId); cachedBitSets = ThreadLocal.withInitial(() -> new SparseFixedBitSet(graph.size())); - globalBruteForceRows = Integer.MAX_VALUE; optimizeFor = index.indexWriterConfig().getOptimizeFor(); + column = index.termType().columnMetadata(); } @Override @@ -90,75 +102,78 @@ public long indexFileCacheSize() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context) throws IOException + public KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext) throws IOException + { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator orderBy(Expression orderer, AbstractBounds keyRange, QueryContext context) throws IOException { - int limit = context.vectorContext().limit(); + int limit = context.limit(); if (logger.isTraceEnabled()) - logger.trace(index.identifier().logMessage("Searching on expression '{}'..."), exp); + logger.trace(index.identifier().logMessage("Searching on expression '{}'..."), orderer); - if (exp.getIndexOperator() != Expression.IndexOperator.ANN) - throw new IllegalArgumentException(index.identifier().logMessage("Unsupported expression during ANN index query: " + exp)); + if (orderer.getIndexOperator() != Expression.IndexOperator.ANN) + throw new IllegalArgumentException(index.identifier().logMessage("Unsupported expression during ANN index query: " + orderer)); int topK = optimizeFor.topKFor(limit); - BitsOrPostingList bitsOrPostingList = bitsOrPostingListForKeyRange(context.vectorContext(), keyRange, topK); - if (bitsOrPostingList.skipANN()) - return toPrimaryKeyIterator(bitsOrPostingList.postingList(), context); - - float[] queryVector = index.termType().decomposeVector(exp.lower().value.raw.duplicate()); - VectorPostingList vectorPostings = graph.search(queryVector, topK, limit, bitsOrPostingList.getBits()); - if (bitsOrPostingList.expectedNodesVisited >= 0) - updateExpectedNodes(vectorPostings.getVisitedCount(), bitsOrPostingList.expectedNodesVisited); - return 
toPrimaryKeyIterator(vectorPostings, context); + + float[] queryVector = index.termType().decomposeVector(orderer.lower().value.raw.duplicate()); + CloseableIterator result = searchInternal(keyRange, queryVector, limit, topK); + return toScoreSortedIterator(result); } - /** - * Return bit set we need to search the graph; otherwise return posting list to bypass the graph - */ - private BitsOrPostingList bitsOrPostingListForKeyRange(VectorQueryContext context, AbstractBounds keyRange, int limit) throws IOException + private CloseableIterator searchInternal(AbstractBounds keyRange, float[] queryVector, int limit, int topK) throws IOException { try (PrimaryKeyMap primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) { // not restricted if (RangeUtil.coversFullRing(keyRange)) - return new BitsOrPostingList(context.bitsetForShadowedPrimaryKeys(metadata, primaryKeyMap, graph)); + return searchInternalUnrestricted(queryVector, limit, topK); + // it will return the next row id if given key is not found. long minSSTableRowId = primaryKeyMap.ceiling(keyRange.left.getToken()); // If we didn't find the first key, we won't find the last primary key either if (minSSTableRowId < 0) - return new BitsOrPostingList(PostingList.EMPTY); + return CloseableIterator.empty(); long maxSSTableRowId = getMaxSSTableRowId(primaryKeyMap, keyRange.right); if (minSSTableRowId > maxSSTableRowId) - return new BitsOrPostingList(PostingList.EMPTY); + return CloseableIterator.empty(); // if it covers entire segment, skip bit set if (minSSTableRowId <= metadata.minSSTableRowId && maxSSTableRowId >= metadata.maxSSTableRowId) - return new BitsOrPostingList(context.bitsetForShadowedPrimaryKeys(metadata, primaryKeyMap, graph)); + return searchInternalUnrestricted(queryVector, limit, topK); minSSTableRowId = Math.max(minSSTableRowId, metadata.minSSTableRowId); maxSSTableRowId = min(maxSSTableRowId, metadata.maxSSTableRowId); - // If num of matches are not bigger than limit, skip ANN. - // (nRows should not include shadowed rows, but context doesn't break those out by segment, - // so we will live with the inaccuracy.) + // If the number of matches is not bigger than the limit, skip graph search and lazily sort by brute force. int nRows = Math.toIntExact(maxSSTableRowId - minSSTableRowId + 1); - int maxBruteForceRows = min(globalBruteForceRows, maxBruteForceRows(limit, nRows, graph.size())); + int maxBruteForceRows = maxBruteForceRows(limit, nRows, graph.size()); if (logger.isTraceEnabled()) logger.trace("Search range covers {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", nRows, maxBruteForceRows, graph.size(), limit); Tracing.trace("Search range covers {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", nRows, maxBruteForceRows, graph.size(), limit); - if (nRows <= maxBruteForceRows) + boolean shouldBruteForce = FORCE_BRUTE_FORCE_ANN == null ? 
nRows <= maxBruteForceRows : FORCE_BRUTE_FORCE_ANN; + if (shouldBruteForce) { - IntArrayList postings = new IntArrayList(Math.toIntExact(nRows), -1); - for (long sstableRowId = minSSTableRowId; sstableRowId <= maxSSTableRowId; sstableRowId++) + SegmentRowIdOrdinalPairs segmentOrdinalPairs = new SegmentRowIdOrdinalPairs(Math.toIntExact(nRows)); + try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = graph.getOrdinalsView()) { - if (context.shouldInclude(sstableRowId, primaryKeyMap)) - postings.addInt(metadata.toSegmentRowId(sstableRowId)); + for (long sstableRowId = minSSTableRowId; sstableRowId <= maxSSTableRowId; sstableRowId++) + { + int segmentRowId = metadata.toSegmentRowId(sstableRowId); + int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId); + if (ordinal >= 0) + segmentOrdinalPairs.add(segmentRowId, ordinal); + } } - return new BitsOrPostingList(new IntArrayPostingList(postings.toIntArray())); + return orderByBruteForce(queryVector, segmentOrdinalPairs, limit, topK); } // create a bitset of ordinals corresponding to the rows in the given key range @@ -168,15 +183,12 @@ private BitsOrPostingList bitsOrPostingListForKeyRange(VectorQueryContext contex { for (long sstableRowId = minSSTableRowId; sstableRowId <= maxSSTableRowId; sstableRowId++) { - if (context.shouldInclude(sstableRowId, primaryKeyMap)) + int segmentRowId = metadata.toSegmentRowId(sstableRowId); + int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId); + if (ordinal >= 0) { - int segmentRowId = metadata.toSegmentRowId(sstableRowId); - int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId); - if (ordinal >= 0) - { - bits.set(ordinal); - hasMatches = true; - } + bits.set(ordinal); + hasMatches = true; } } } @@ -186,12 +198,21 @@ private BitsOrPostingList bitsOrPostingListForKeyRange(VectorQueryContext contex } if (!hasMatches) - return new BitsOrPostingList(PostingList.EMPTY); + return CloseableIterator.empty(); - return new BitsOrPostingList(bits, VectorMemoryIndex.expectedNodesVisited(limit, nRows, graph.size())); + int expectedNodesVisited = expectedNodesVisited(limit, bits.cardinality(), graph.size()); + IntConsumer nodesVisitedConsumer = nodesVisited -> updateExpectedNodes(nodesVisited, expectedNodesVisited); + return graph.search(queryVector, topK, limit, bits, nodesVisitedConsumer); } } + private CloseableIterator searchInternalUnrestricted(float[] queryVector, int limit, int topK) + { + int expectedNodesVisited = expectedNodesVisited(limit, graph.size(), graph.size()); + IntConsumer nodesVisitedConsumer = nodesVisited -> updateExpectedNodes(nodesVisited, expectedNodesVisited); + return graph.search(queryVector, topK, limit, new Bits.MatchAllBits(graph.size()), nodesVisitedConsumer); + } + private long getMaxSSTableRowId(PrimaryKeyMap primaryKeyMap, PartitionPosition right) { // if the right token is the minimum token, there is no upper bound on the keyRange and @@ -212,29 +233,72 @@ private SparseFixedBitSet bitSetForSearch() return bits; } + /** + * Produces a descending score ordered iterator over the rows in the given segment. Branches depending on the number + * of rows to consider and whether the graph has compressed vectors available for faster comparisons. 
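+     * <p>
+     * For example: with no compressed vectors available, or with at most topK candidate pairs, every candidate is
+     * scored exactly in a single pass; with compressed vectors and more than topK candidates, the two pass path
+     * orders candidates by their in-memory approximate scores first and resolves full precision scores lazily.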
+ */ + private CloseableIterator orderByBruteForce(float[] queryVector, SegmentRowIdOrdinalPairs segmentOrdinalPairs, int limit, int topK) throws IOException + { + if (segmentOrdinalPairs.size() == 0) + return CloseableIterator.empty(); + + // If we have more than topK segmentOrdinalPairs, we do a two pass partial sort by first getting the approximate + // similarity score via the PQ vectors that are already in memory and then by hitting disk to get the full + // precision vectors to get the full precision similarity score. + if (graph.getCompressedVectors() != null && segmentOrdinalPairs.size() > topK) + return orderByBruteForceTwoPass(graph.getCompressedVectors(), queryVector, segmentOrdinalPairs, limit, topK); + + try (GraphIndex.View view = graph.getView()) + { + NeighborSimilarity.ExactScoreFunction esf = graph.getExactScoreFunction(queryVector, view); + NeighborQueue scoredRowIds = segmentOrdinalPairs.mapToSegmentRowIdScoreHeap(esf); + return new NeighborQueueRowIdIterator(scoredRowIds); + } + catch (Exception e) + { + throw new IOException(e); + } + } + + /** + * Materialize the compressed vectors for the given segment row ids, put them into a priority queue ordered by + * approximate similarity score, and then pass to the {@link BruteForceRowIdIterator} to lazily resolve the + * full resolution ordering as needed. + */ + private CloseableIterator orderByBruteForceTwoPass(CompressedVectors cv, + float[] queryVector, + SegmentRowIdOrdinalPairs segmentOrdinalPairs, + int limit, + int rerankK) + { + NeighborSimilarity.ApproximateScoreFunction scoreFunction = graph.getApproximateScoreFunction(queryVector); + // Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can + // retrieve both values with O(1) lookup when we need to resolve the full resolution score in the + // BruteForceRowIdIterator. + NeighborQueue approximateScoreHeap = segmentOrdinalPairs.mapToIndexScoreIterator(scoreFunction); + GraphIndex.View view = graph.getView(); + NeighborSimilarity.ExactScoreFunction esf = graph.getExactScoreFunction(queryVector, view); + return new BruteForceRowIdIterator(approximateScoreHeap, segmentOrdinalPairs, esf, limit, rerankK, view); + } + @Override - public KeyRangeIterator limitToTopKResults(QueryContext context, List primaryKeys, Expression expression) throws IOException + public CloseableIterator orderResultsBy(QueryContext context, List results, Expression orderer) throws IOException { - int limit = context.vectorContext().limit(); + int limit = context.limit(); // VSTODO would it be better to do a binary search to find the boundaries? - List keysInRange = primaryKeys.stream() - .dropWhile(k -> k.compareTo(metadata.minKey) < 0) - .takeWhile(k -> k.compareTo(metadata.maxKey) <= 0) - .collect(Collectors.toList()); + List keysInRange = results.stream() + .dropWhile(k -> k.compareTo(metadata.minKey) < 0) + .takeWhile(k -> k.compareTo(metadata.maxKey) <= 0) + .collect(Collectors.toList()); if (keysInRange.isEmpty()) - return KeyRangeIterator.empty(); - int topK = optimizeFor.topKFor(limit); - if (shouldUseBruteForce(topK, limit, keysInRange.size())) - return new KeyRangeListIterator(metadata.minKey, metadata.maxKey, keysInRange); + return CloseableIterator.empty(); try (PrimaryKeyMap primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) { // the iterator represents keys from the whole table -- we'll only pull of those that // are from our own token range, so we can use row ids to order the results by vector similarity. 
- int maxSegmentRowId = metadata.toSegmentRowId(metadata.maxSSTableRowId); - SparseFixedBitSet bits = bitSetForSearch(); - IntArrayList rowIds = new IntArrayList(); - try (var ordinalsView = graph.getOrdinalsView()) + SegmentRowIdOrdinalPairs segmentOrdinalPairs = new SegmentRowIdOrdinalPairs(keysInRange.size()); + try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = graph.getOrdinalsView()) { for (PrimaryKey primaryKey : keysInRange) { @@ -249,46 +313,55 @@ public KeyRangeIterator limitToTopKResults(QueryContext context, List= 0) - bits.set(ordinal); + segmentOrdinalPairs.add(segmentRowId, ordinal); } } - if (shouldUseBruteForce(topK, limit, rowIds.size())) - return toPrimaryKeyIterator(new IntArrayPostingList(rowIds.toIntArray()), context); + int topK = optimizeFor.topKFor(limit); + float[] queryVector = index.termType().decomposeVector(orderer.lower().value.raw.duplicate()); + if (shouldUseBruteForce(topK, limit, segmentOrdinalPairs.size())) + { + return toScoreSortedIterator(orderByBruteForce(queryVector, segmentOrdinalPairs, limit, topK)); + } + + SparseFixedBitSet bits = bitSetForSearch(); + segmentOrdinalPairs.forEachOrdinal(bits::set); // else ask the index to perform a search limited to the bits we created - float[] queryVector = index.termType().decomposeVector(expression.lower().value.raw.duplicate()); - VectorPostingList results = graph.search(queryVector, topK, limit, bits); - updateExpectedNodes(results.getVisitedCount(), expectedNodesVisited(topK, maxSegmentRowId, graph.size())); - return toPrimaryKeyIterator(results, context); + int expectedNodesVisited = expectedNodesVisited(limit, segmentOrdinalPairs.size(), graph.size()); + IntConsumer nodesVisitedConsumer = nodesVisited -> updateExpectedNodes(nodesVisited, expectedNodesVisited); + CloseableIterator result = graph.search(queryVector, topK, limit, bits, nodesVisitedConsumer); + return toScoreSortedIterator(result); } } private boolean shouldUseBruteForce(int topK, int limit, int numRows) { // if we have a small number of results then let TopK processor do exact NN computation - int maxBruteForceRows = min(globalBruteForceRows, maxBruteForceRows(topK, numRows, graph.size())); + int maxBruteForceRows = maxBruteForceRows(topK, numRows, graph.size()); if (logger.isTraceEnabled()) logger.trace("SAI materialized {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", numRows, maxBruteForceRows, graph.size(), limit); Tracing.trace("SAI materialized {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", numRows, maxBruteForceRows, graph.size(), limit); - return numRows <= maxBruteForceRows; + return FORCE_BRUTE_FORCE_ANN == null ? numRows <= maxBruteForceRows + : FORCE_BRUTE_FORCE_ANN; } private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize) { - int expectedNodes = expectedNodesVisited(limit, nPermittedOrdinals, graphSize); - // ANN index will do a bunch of extra work besides the full comparisons (performing PQ similarity for each edge); - // brute force from sstable will also do a bunch of extra work (going through trie index to look up row). - // VSTODO I'm not sure which one is more expensive (and it depends on things like sstable chunk cache hit ratio) - // so I'm leaving it as a 1:1 ratio for now. 
- return max(limit, expectedNodes); + int expectedNodesVisited = expectedNodesVisited(limit, nPermittedOrdinals, graphSize); + int expectedComparisons = index.indexWriterConfig().getMaximumNodeConnections() * expectedNodesVisited; + // in-memory comparisons are cheaper than pulling a row off disk and then comparing + // VSTODO this is dramatically oversimplified + // larger dimension should increase this, because comparisons are more expensive + // lower chunk cache hit ratio should decrease this, because loading rows is more expensive + double memoryToDiskFactor = 0.25; + return (int) max(limit, memoryToDiskFactor * expectedComparisons); } private int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize) @@ -318,49 +391,14 @@ public void close() throws IOException graph.close(); } - private static class BitsOrPostingList + private CloseableIterator toScoreSortedIterator(CloseableIterator rowIdIterator) throws IOException { - private final Bits bits; - private final int expectedNodesVisited; - private final PostingList postingList; - - public BitsOrPostingList(@Nullable Bits bits, int expectedNodesVisited) - { - this.bits = bits; - this.expectedNodesVisited = expectedNodesVisited; - this.postingList = null; - } - - public BitsOrPostingList(@Nullable Bits bits) - { - this.bits = bits; - this.postingList = null; - this.expectedNodesVisited = -1; - } - - public BitsOrPostingList(PostingList postingList) + if (!rowIdIterator.hasNext()) { - this.bits = null; - this.postingList = Preconditions.checkNotNull(postingList); - this.expectedNodesVisited = -1; + FileUtils.closeQuietly(rowIdIterator); + return CloseableIterator.empty(); } - @Nullable - public Bits getBits() - { - Preconditions.checkState(!skipANN()); - return bits; - } - - public PostingList postingList() - { - Preconditions.checkState(skipANN()); - return postingList; - } - - public boolean skipANN() - { - return postingList != null; - } + return new RowIdToPrimaryKeyWithScoreIterator(column, primaryKeyMapFactory, rowIdIterator, metadata.rowIdOffset); } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/AutoResumingNodeScoreIterator.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/AutoResumingNodeScoreIterator.java new file mode 100644 index 000000000000..5a9f9035f04b --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/AutoResumingNodeScoreIterator.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.function.IntConsumer; + +import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.GraphSearcher; +import io.github.jbellis.jvector.graph.NeighborSimilarity; +import io.github.jbellis.jvector.graph.SearchResult; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.GrowableBitSet; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.utils.AbstractIterator; + +/** + * An iterator over {@link SearchResult.NodeScore} backed by a {@link SearchResult} that resumes search + * when the backing {@link SearchResult} is exhausted. + */ +public class AutoResumingNodeScoreIterator extends AbstractIterator +{ + private final GraphSearcher searcher; + private final GraphIndex.View view; + private final NeighborSimilarity.ScoreFunction scoreFunction; + private final NeighborSimilarity.ReRanker reRanker; + private final int topK; + private final Bits acceptBits; + private final boolean inMemory; + private final String source; + private final IntConsumer nodesVisitedConsumer; + private Iterator nodeScores = Collections.emptyIterator(); + private int cumulativeNodesVisited; + + // Defer initialization since it is only needed if we need to resume search + private SkipVisitedBits visited = null; + private SearchResult.NodeScore[] previousResult = null; + + /** + * Create a new {@link AutoResumingNodeScoreIterator} that iterates over the provided {@link SearchResult}. + * If the {@link SearchResult} is consumed, it retrieves the next {@link SearchResult} until the search returns + * no more results. + * @param searcher the {@link GraphSearcher} to use to search and resume search. + * @param nodesVisitedConsumer a consumer that accepts the total number of nodes visited + * @param inMemory whether the graph is in memory or on disk (used for trace logging) + * @param source the source of the search (used for trace logging) + * @param view the view used to read from disk. It will be closed when the iterator is closed. + */ + public AutoResumingNodeScoreIterator(GraphSearcher searcher, + NeighborSimilarity.ScoreFunction scoreFunction, + NeighborSimilarity.ReRanker reRanker, + int topK, + Bits acceptBits, + IntConsumer nodesVisitedConsumer, + boolean inMemory, + String source, + GraphIndex.View view) + { + this.searcher = searcher; + this.scoreFunction = scoreFunction; + this.reRanker = reRanker; + this.topK = topK; + this.acceptBits = acceptBits; + + this.cumulativeNodesVisited = 0; + this.nodesVisitedConsumer = nodesVisitedConsumer; + this.inMemory = inMemory; + this.source = source; + this.view = view; + } + + @Override + protected SearchResult.NodeScore computeNext() + { + if (nodeScores.hasNext()) + return nodeScores.next(); + + // Add result from previous search to visited bits + if (previousResult != null) + { + if (visited == null) + visited = new SkipVisitedBits(acceptBits, previousResult.length); + visited.visited(previousResult); + } + Bits bits = visited == null ? 
acceptBits : visited; + SearchResult nextResult = searcher.search(scoreFunction, reRanker, topK, bits); + + // Record metrics (we add here instead of overwriting because re-queries are expensive proportional to the + // number of visited nodes and even though we throw away some of those results, it helps us determine the + // right path for brute force vs. ANN) + cumulativeNodesVisited += nextResult.getVisitedCount(); + + if (Tracing.isTracing()) + { + Tracing.trace("{} based ANN {} for topK {} visited {} nodes to return {} results from {}", + inMemory ? "Memory" : "Disk", previousResult == null ? "initial" : "re-query", + topK, nextResult.getVisitedCount(), nextResult.getNodes().length, source); + } + + previousResult = nextResult.getNodes(); + // If the next result is empty, we are done searching. + nodeScores = Arrays.stream(nextResult.getNodes()).iterator(); + return nodeScores.hasNext() ? nodeScores.next() : endOfData(); + } + + @Override + public void close() + { + nodesVisitedConsumer.accept(cumulativeNodesVisited); + FileUtils.closeQuietly(view); + } + + private static class SkipVisitedBits implements Bits + { + private final Bits acceptBits; + private final GrowableBitSet visited; + + SkipVisitedBits(Bits acceptBits, int initialBits) + { + this.acceptBits = acceptBits; + this.visited = new GrowableBitSet(initialBits); + } + + void visited(SearchResult.NodeScore[] nodes) + { + for (SearchResult.NodeScore nodeScore : nodes) + visited.set(nodeScore.node); + } + + @Override + public boolean get(int i) + { + return acceptBits.get(i) && !visited.get(i); + } + + @Override + public int length() + { + return acceptBits.length(); + } + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/BruteForceRowIdIterator.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/BruteForceRowIdIterator.java new file mode 100644 index 000000000000..9be94b08e308 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/BruteForceRowIdIterator.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import javax.annotation.concurrent.NotThreadSafe; + +import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.NeighborQueue; +import io.github.jbellis.jvector.graph.NeighborSimilarity; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.AbstractIterator; + + +/** + * An iterator over {@link RowIdWithScore} that lazily consumes from a {@link NeighborQueue} of approximate scores. + *
* <p>
    + * The idea is that we maintain the same level of accuracy as we would get from a graph search, by re-ranking the top + * `k` best approximate scores at a time with the full resolution vectors to return the top `limit`. + *
* <p>
    + * For example, suppose that limit=3 and k=5 and we have ten elements. After our first re-ranking batch, we have + * ABDEF????? + * We will return A, B, and D; if more elements are requested, we will re-rank another 5 (so three more, including + * the two remaining from the first batch). Here we uncover C, G, and H, and order them appropriately: + * CEFGH?? + * This illustrates that, also like a graph search, we only guarantee ordering of results within a re-ranking batch, + * not globally. + *
* <p>
    + * Note that we deliberately do not fetch new items from the approximate list until the first batch of `limit`-many + * is consumed. We do this because we expect that most often the first limit-many will pass the final verification + * and only query more if some didn't (e.g. because the vector was deleted in a newer sstable). + *
* <p>
    + * As an implementation detail, we use a heap to maintain state rather than a List and sorting. + */ +@NotThreadSafe +public class BruteForceRowIdIterator extends AbstractIterator +{ + // We use two binary heaps (NeighborQueue) because we do not need an eager ordering of + // these results. Depending on how many sstables the query hits and the relative scores of vectors from those + // sstables, we may not need to return more than the first handful of scores. + // Heap with compressed vector scores + private final NeighborQueue approximateScoreQueue; + private final SegmentRowIdOrdinalPairs segmentOrdinalPairs; + // Use the jvector NeighborQueue to avoid unnecessary object allocations + private final NeighborQueue exactScoreQueue; + private final NeighborSimilarity.ExactScoreFunction reranker; + private final GraphIndex.View view; + private final int topK; + private final int limit; + private int rerankedCount; + + /** + * @param approximateScoreQueue A heap of indexes ordered by their approximate similarity scores + * @param segmentOrdinalPairs A mapping from the index in the approximateScoreQueue to the node's rowId and ordinal + * @param reranker A function that takes a graph ordinal and returns the exact similarity score + * @param limit The query limit + * @param topK The number of vectors to resolve and score before returning results + * @param view The view of the graph, passed so we can close it when the iterator is closed + */ + public BruteForceRowIdIterator(NeighborQueue approximateScoreQueue, + SegmentRowIdOrdinalPairs segmentOrdinalPairs, + NeighborSimilarity.ExactScoreFunction reranker, + int limit, + int topK, + GraphIndex.View view) + { + this.approximateScoreQueue = approximateScoreQueue; + this.segmentOrdinalPairs = segmentOrdinalPairs; + this.exactScoreQueue = new NeighborQueue(limit, true); + this.reranker = reranker; + assert topK >= limit : "topK must be greater than or equal to limit. 
Found: " + topK + " < " + limit; + this.limit = limit; + this.topK = topK; + this.rerankedCount = topK; // placeholder to kick off computeNext + this.view = view; + } + + @Override + protected RowIdWithScore computeNext() + { + int consumed = rerankedCount - exactScoreQueue.size(); + if (consumed >= limit) + { + // Refill the exactScoreQueue until it reaches topK exact scores, or the approximate score queue is empty + while (approximateScoreQueue.size() > 0 && exactScoreQueue.size() < topK) + { + int segmentOrdinalIndex = approximateScoreQueue.pop(); + int rowId = segmentOrdinalPairs.getSegmentRowId(segmentOrdinalIndex); + int ordinal = segmentOrdinalPairs.getOrdinal(segmentOrdinalIndex); + float score = reranker.similarityTo(ordinal); + exactScoreQueue.add(rowId, score); + } + rerankedCount = exactScoreQueue.size(); + } + if (exactScoreQueue.size() == 0) + return endOfData(); + + float score = exactScoreQueue.topScore(); + int rowId = exactScoreQueue.pop(); + return new RowIdWithScore(rowId, score); + } + + @Override + public void close() + { + FileUtils.closeQuietly(view); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/DiskAnn.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/DiskAnn.java index 196802df4395..a0f3c9c6ab49 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/DiskAnn.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/DiskAnn.java @@ -19,27 +19,23 @@ package org.apache.cassandra.index.sai.disk.v1.vector; import java.io.IOException; -import java.util.Arrays; -import java.util.Iterator; -import java.util.NoSuchElementException; -import java.util.PrimitiveIterator; -import java.util.stream.IntStream; +import java.util.function.IntConsumer; import org.apache.cassandra.index.sai.disk.format.IndexComponent; import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig; import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles; -import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList; import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.io.util.FileHandle; -import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.CloseableIterator; +import org.apache.cassandra.utils.Throwables; import io.github.jbellis.jvector.disk.CachingGraphIndex; import io.github.jbellis.jvector.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.GraphIndex; import io.github.jbellis.jvector.graph.GraphSearcher; import io.github.jbellis.jvector.graph.NeighborSimilarity; -import io.github.jbellis.jvector.graph.SearchResult; -import io.github.jbellis.jvector.graph.SearchResult.NodeScore; import io.github.jbellis.jvector.pq.CompressedVectors; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; @@ -50,13 +46,15 @@ public class DiskAnn implements AutoCloseable private final OnDiskOrdinalsMap ordinalsMap; private final CachingGraphIndex graph; private final VectorSimilarityFunction similarityFunction; + private final String source; // only one of these will be not null private final CompressedVectors compressedVectors; - public DiskAnn(SegmentMetadata.ComponentMetadataMap componentMetadatas, PerColumnIndexFiles indexFiles, IndexWriterConfig config) throws IOException + public DiskAnn(SegmentMetadata.ComponentMetadataMap componentMetadatas, PerColumnIndexFiles indexFiles, IndexWriterConfig 
config, SSTableId sstableId) throws IOException { similarityFunction = config.getSimilarityFunction(); + source = sstableId.toString(); SegmentMetadata.ComponentMetadata termsMetadata = componentMetadatas.get(IndexComponent.TERMS_DATA); graphHandle = indexFiles.termsData(); @@ -87,86 +85,56 @@ public int size() return graph.size(); } + public CompressedVectors getCompressedVectors() + { + return compressedVectors; + } + /** * @return Row IDs associated with the topK vectors near the query */ - public VectorPostingList search(float[] queryVector, int topK, int limit, Bits acceptBits) + public CloseableIterator search(float[] queryVector, int topK, int limit, Bits acceptBits, IntConsumer nodesVisitedConsumer) { OnHeapGraph.validateIndexable(queryVector, similarityFunction); GraphIndex.View view = graph.getView(); - GraphSearcher searcher = new GraphSearcher.Builder<>(view).build(); - NeighborSimilarity.ScoreFunction scoreFunction; - NeighborSimilarity.ReRanker reRanker; - if (compressedVectors == null) + try { - scoreFunction = (NeighborSimilarity.ExactScoreFunction) - i -> similarityFunction.compare(queryVector, view.getVector(i)); - reRanker = null; + GraphSearcher searcher = new GraphSearcher.Builder<>(view).build(); + NeighborSimilarity.ScoreFunction scoreFunction; + NeighborSimilarity.ReRanker reRanker; + if (compressedVectors == null) + { + scoreFunction = (NeighborSimilarity.ExactScoreFunction) + i -> similarityFunction.compare(queryVector, view.getVector(i)); + reRanker = null; + } + else + { + scoreFunction = compressedVectors.approximateScoreFunctionFor(queryVector, similarityFunction); + reRanker = (i, map) -> similarityFunction.compare(queryVector, map.get(i)); + } + Bits acceptedBits = ordinalsMap.ignoringDeleted(acceptBits); + // Search is done within the iterator to keep track of visited nodes. The resulting iterator + // searches until the graph is exhausted. 
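+            // The returned iterator owns the resources: closing it closes the row ids view and the node score
+            // iterator, which in turn closes the graph view, so no try-with-resources is used here.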
+            AutoResumingNodeScoreIterator nodeScoreIterator = new AutoResumingNodeScoreIterator(searcher, scoreFunction, reRanker, topK, acceptedBits, nodesVisitedConsumer, false, source, view);
+            return new NodeScoreToRowIdWithScoreIterator(nodeScoreIterator, ordinalsMap.getRowIdsView());
        }
-        else
+        catch (Throwable e)
        {
-            scoreFunction = compressedVectors.approximateScoreFunctionFor(queryVector, similarityFunction);
-            reRanker = (i, map) -> similarityFunction.compare(queryVector, map.get(i));
+            FileUtils.closeQuietly(view);
+            throw Throwables.unchecked(e);
        }
-        SearchResult result = searcher.search(scoreFunction,
-                                              reRanker,
-                                              topK,
-                                              ordinalsMap.ignoringDeleted(acceptBits));
-        Tracing.trace("DiskANN search visited {} nodes to return {} results", result.getVisitedCount(), result.getNodes().length);
-        return annRowIdsToPostings(result, limit);
    }

-    private class RowIdIterator implements PrimitiveIterator.OfInt, AutoCloseable
+    public NeighborSimilarity.ApproximateScoreFunction getApproximateScoreFunction(float[] queryVector)
    {
-        private final Iterator it;
-        private final OnDiskOrdinalsMap.RowIdsView rowIdsView = ordinalsMap.getRowIdsView();
-
-        private OfInt segmentRowIdIterator = IntStream.empty().iterator();
-
-        public RowIdIterator(NodeScore[] results)
-        {
-            this.it = Arrays.stream(results).iterator();
-        }
-
-        @Override
-        public boolean hasNext()
-        {
-            while (!segmentRowIdIterator.hasNext() && it.hasNext())
-            {
-                try
-                {
-                    int ordinal = it.next().node;
-                    segmentRowIdIterator = Arrays.stream(rowIdsView.getSegmentRowIdsMatching(ordinal)).iterator();
-                }
-                catch (IOException e)
-                {
-                    throw new RuntimeException(e);
-                }
-            }
-            return segmentRowIdIterator.hasNext();
-        }
-
-        @Override
-        public int nextInt() {
-            if (!hasNext())
-                throw new NoSuchElementException();
-            return segmentRowIdIterator.nextInt();
-        }
-
-        @Override
-        public void close()
-        {
-            rowIdsView.close();
-        }
+        return compressedVectors.approximateScoreFunctionFor(queryVector, similarityFunction);
    }

-    private VectorPostingList annRowIdsToPostings(SearchResult results, int limit)
+    public NeighborSimilarity.ExactScoreFunction getExactScoreFunction(float[] queryVector, GraphIndex.View view)
    {
-        try (var iterator = new RowIdIterator(results.getNodes()))
-        {
-            return new VectorPostingList(iterator, limit, results.getVisitedCount());
-        }
+        return i -> similarityFunction.compare(queryVector, view.getVector(i));
    }

@@ -181,4 +149,13 @@ public OnDiskOrdinalsMap.OrdinalsView getOrdinalsView()
    {
        return ordinalsMap.getOrdinalsView();
    }
+
+    /**
+     * Get the graph view; callers must close the view.
+     * @return the graph view
+     */
+    public GraphIndex.View getView()
+    {
+        return graph.getView();
+    }
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/NeighborQueueRowIdIterator.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/NeighborQueueRowIdIterator.java
new file mode 100644
index 000000000000..e234caee531b
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/NeighborQueueRowIdIterator.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import io.github.jbellis.jvector.graph.NeighborQueue; +import org.apache.cassandra.utils.AbstractIterator; + +/** + * An iterator over {@link RowIdWithScore} that lazily consumes a {@link NeighborQueue}. + */ +public class NeighborQueueRowIdIterator extends AbstractIterator +{ + private final NeighborQueue scoreQueue; + + public NeighborQueueRowIdIterator(NeighborQueue scoreQueue) + { + this.scoreQueue = scoreQueue; + } + + @Override + protected RowIdWithScore computeNext() + { + if (scoreQueue.size() == 0) + return endOfData(); + float score = scoreQueue.topScore(); + int rowId = scoreQueue.pop(); + return new RowIdWithScore(rowId, score); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/NodeScoreToRowIdWithScoreIterator.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/NodeScoreToRowIdWithScoreIterator.java new file mode 100644 index 000000000000..a23edeaf24d4 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/NodeScoreToRowIdWithScoreIterator.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.PrimitiveIterator; +import java.util.stream.IntStream; + +import io.github.jbellis.jvector.graph.SearchResult; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.AbstractIterator; +import org.apache.cassandra.utils.CloseableIterator; + +/** + * An iterator over {@link RowIdWithScore} sorted by score descending. The iterator converts ordinals (node ids) to + * segment row ids and pairs them with the score given by the index. 
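+ * <p>
+ * For example, if the graph yields node scores (ordinal 7, 0.92) and then (ordinal 3, 0.87), and ordinal 7 maps
+ * to segment row ids {10, 42} (several rows can share the same vector), the iterator yields (10, 0.92), (42, 0.92)
+ * and then the row ids for ordinal 3, each paired with score 0.87.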
+ */ +public class NodeScoreToRowIdWithScoreIterator extends AbstractIterator +{ + private final CloseableIterator nodeScores; + private final OnDiskOrdinalsMap.RowIdsView rowIdsView; + + private PrimitiveIterator.OfInt segmentRowIdIterator = IntStream.empty().iterator(); + private float currentScore; + + public NodeScoreToRowIdWithScoreIterator(CloseableIterator nodeScores, + OnDiskOrdinalsMap.RowIdsView rowIdsView) + { + this.nodeScores = nodeScores; + this.rowIdsView = rowIdsView; + } + + @Override + protected RowIdWithScore computeNext() + { + try + { + if (segmentRowIdIterator.hasNext()) + return new RowIdWithScore(segmentRowIdIterator.nextInt(), currentScore); + + while (nodeScores.hasNext()) + { + SearchResult.NodeScore result = nodeScores.next(); + currentScore = result.score; + int ordinal = result.node; + segmentRowIdIterator = Arrays.stream(rowIdsView.getSegmentRowIdsMatching(ordinal)).iterator(); + if (segmentRowIdIterator.hasNext()) + return new RowIdWithScore(segmentRowIdIterator.nextInt(), currentScore); + } + return endOfData(); + } + catch (IOException e) + { + throw new RuntimeException(e); + } + } + + @Override + public void close() + { + FileUtils.closeQuietly(rowIdsView); + FileUtils.closeQuietly(nodeScores); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OnHeapGraph.java index 369aac2fde70..108752bb4478 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OnHeapGraph.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OnHeapGraph.java @@ -24,7 +24,6 @@ import java.util.Collection; import java.util.HashSet; import java.util.Map; -import java.util.PriorityQueue; import java.util.Set; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentSkipListMap; @@ -33,12 +32,14 @@ import java.util.stream.IntStream; import org.apache.lucene.util.StringHelper; +import org.cliffc.high_scale_lib.NonBlockingHashMap; import org.cliffc.high_scale_lib.NonBlockingHashMapLong; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.index.sai.disk.format.IndexComponent; import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; @@ -48,8 +49,8 @@ import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; import org.apache.cassandra.index.sai.utils.IndexIdentifier; import org.apache.cassandra.io.util.SequentialWriter; -import org.apache.cassandra.tracing.Tracing; import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.utils.CloseableIterator; import io.github.jbellis.jvector.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.GraphIndex; @@ -61,6 +62,7 @@ import io.github.jbellis.jvector.pq.CompressedVectors; import io.github.jbellis.jvector.pq.ProductQuantization; import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.vector.VectorEncoding; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; @@ -68,37 +70,34 @@ public class OnHeapGraph { private static final Logger logger = LoggerFactory.getLogger(OnHeapGraph.class); + public static final int MIN_PQ_ROWS = 1024; + private final RamAwareVectorValues vectorValues; private final GraphIndexBuilder 
builder;
     private final VectorType vectorType;
     private final VectorSimilarityFunction similarityFunction;
     private final ConcurrentMap> postingsMap;
     private final NonBlockingHashMapLong> postingsByOrdinal;
+    private final NonBlockingHashMap vectorsByKey;
     private final AtomicInteger nextOrdinal = new AtomicInteger();
     private volatile boolean hasDeletions;
-
-    /**
-     * @param termComparator the vector type
-     * @param indexWriterConfig
-     *
-     * Will create a concurrent object.
-     */
-    public OnHeapGraph(AbstractType termComparator, IndexWriterConfig indexWriterConfig)
-    {
-        this(termComparator, indexWriterConfig, true);
-    }
+    private String source;

     /**
      * @param termComparator the vector type
      * @param indexWriterConfig the {@link IndexWriterConfig} for the graph
-     * @param concurrent should be true for memtables, false for compaction. Concurrent allows us to search
-     *                   while building the graph; non-concurrent allows us to avoid synchronization costs.
+     * @param memtable should be provided if attached to a memtable, null otherwise (i.e. compaction). Allows us to
+     *                 configure concurrent search and provide more meaningful trace logging. Concurrent structures
+     *                 allow searching while building the graph; non-concurrent structures avoid synchronization costs.
      */
    @SuppressWarnings("unchecked")
-    public OnHeapGraph(AbstractType termComparator, IndexWriterConfig indexWriterConfig, boolean concurrent)
+    public OnHeapGraph(AbstractType termComparator, IndexWriterConfig indexWriterConfig, Memtable memtable)
    {
        this.vectorType = (VectorType) termComparator;
-        vectorValues = concurrent
+        source = memtable != null
+                 ? memtable.getClass().getSimpleName() + '@' + Integer.toHexString(memtable.hashCode())
+                 : "compaction";
+        vectorValues = memtable != null
                       ? new ConcurrentVectorValues(((VectorType) termComparator).dimension)
                       : new CompactionVectorValues(((VectorType) termComparator));
        similarityFunction = indexWriterConfig.getSimilarityFunction();
@@ -108,6 +107,7 @@ public OnHeapGraph(AbstractType termComparator, IndexWriter
        // is thus a better option than hash-based (which has to look at all elements to compute the hash).
        postingsMap = new ConcurrentSkipListMap<>(Arrays::compare);
        postingsByOrdinal = new NonBlockingHashMapLong<>();
+        vectorsByKey = memtable != null ? new NonBlockingHashMap<>() : null;
        builder = new GraphIndexBuilder<>(vectorValues,
                                          VectorEncoding.FLOAT32,
@@ -155,6 +155,16 @@ public long add(ByteBuffer term, T key, InvalidVectorBehavior behavior)
        }

        long bytesUsed = 0L;
+
+        // Store a cached reference to the vector for brute force computations later. Because insertions
+        // for the same primary key are guaranteed to be sequential, there is no race condition here.
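+        // (See remove(): on update, add is called before remove, so remove(key, vector) only clears the cached
+        // entry when it still maps to the removed vector.)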
+ if (vectorsByKey != null) + { + vectorsByKey.put(key, vector); + // The size of the entries themselves are counted below, so just count the two extra references + bytesUsed += RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L; + } + VectorPostings postings = postingsMap.get(vector); // if the vector is already in the graph, all that happens is that the postings list is updated // otherwise, we add the vector in this order: @@ -241,6 +251,13 @@ public Collection keysFromOrdinal(int node) return postingsByOrdinal.get(node).getPostings(); } + public float[] vectorForKey(T key) + { + if (vectorsByKey == null) + throw new IllegalStateException("vectorsByKey is not initialized"); + return vectorsByKey.get(key); + } + public long remove(ByteBuffer term, T key) { assert term != null && term.remaining() != 0; @@ -256,31 +273,35 @@ public long remove(ByteBuffer term, T key) } hasDeletions = true; - return postings.remove(key); + long bytesUsed = postings.remove(key); + + if (vectorsByKey != null) + { + // On updates to a row, we call add then remove, so we must pass the key's value to ensure we only remove + // the deleted vector from vectorsByKey. + vectorsByKey.remove(key, vector); + } + + return bytesUsed; } /** * @return keys (PrimaryKey or segment row id) associated with the topK vectors near the query */ - public PriorityQueue search(float[] queryVector, int limit, Bits toAccept) + public CloseableIterator search(float[] queryVector, int limit, Bits toAccept) { validateIndexable(queryVector, similarityFunction); // search() errors out when an empty graph is passed to it if (vectorValues.size() == 0) - return new PriorityQueue<>(); + return CloseableIterator.empty(); Bits bits = hasDeletions ? BitsUtil.bitsIgnoringDeleted(toAccept, postingsByOrdinal) : toAccept; GraphIndex graph = builder.getGraph(); - GraphSearcher searcher = new GraphSearcher.Builder<>(graph.getView()).withConcurrentUpdates().build(); + GraphIndex.View view = graph.getView(); + GraphSearcher searcher = new GraphSearcher.Builder<>(view).withConcurrentUpdates().build(); NeighborSimilarity.ExactScoreFunction scoreFunction = node2 -> vectorCompareFunction(queryVector, node2); - SearchResult result = searcher.search(scoreFunction, null, limit, bits); - Tracing.trace("ANN search visited {} in-memory nodes to return {} results", result.getVisitedCount(), result.getNodes().length); - SearchResult.NodeScore[] a = result.getNodes(); - PriorityQueue keyQueue = new PriorityQueue<>(); - for (int i = 0; i < a.length; i++) - keyQueue.addAll(keysFromOrdinal(a[i].node)); - return keyQueue; + return new AutoResumingNodeScoreIterator(searcher, scoreFunction, null, limit, bits, v -> {}, true, source, view); } public SegmentMetadata.ComponentMetadataMap writeData(IndexDescriptor indexDescriptor, IndexIdentifier indexIdentifier, Function postingTransformer) throws IOException @@ -352,8 +373,8 @@ private long writePQ(SequentialWriter writer) throws IOException { // don't bother with PQ if there are fewer than 1K vectors int M = vectorValues.dimension() / 2; - writer.writeBoolean(vectorValues.size() >= 1024); - if (vectorValues.size() < 1024) + writer.writeBoolean(vectorValues.size() >= MIN_PQ_ROWS); + if (vectorValues.size() < MIN_PQ_ROWS) { logger.debug("Skipping PQ for only {} vectors", vectorValues.size()); return writer.position(); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OptimizeFor.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OptimizeFor.java index e2a566ea8ba0..236e9e5b1c6a 100644 --- 
a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OptimizeFor.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/OptimizeFor.java
@@ -43,7 +43,7 @@ public enum OptimizeFor
     public int topKFor(int limit)
     {
-        return (int)(limitMultiplier.apply(limit) * limit);
+        return (int)(Math.max(1.0, limitMultiplier.apply(limit)) * limit);
     }

     public static OptimizeFor fromString(String value)
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/PrimaryKeyWithScore.java
new file mode 100644
index 000000000000..123bddc1245e
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/PrimaryKeyWithScore.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.v1.vector;
+
+import org.apache.cassandra.db.memtable.Memtable;
+import org.apache.cassandra.db.rows.Cell;
+import org.apache.cassandra.db.rows.Row;
+import org.apache.cassandra.index.sai.utils.CellWithSource;
+import org.apache.cassandra.index.sai.utils.PrimaryKey;
+import org.apache.cassandra.io.sstable.SSTableId;
+import org.apache.cassandra.schema.ColumnMetadata;
+
+/**
+ * A {@link PrimaryKey} with an associated index score. The score is not used to determine equality or hash code,
+ * but it is used to compare {@link PrimaryKeyWithScore} objects, ordering them by score descending.
+ * Note: this class has a natural ordering that is inconsistent with equals.
+ */
+public class PrimaryKeyWithScore implements Comparable
+{
+    protected final ColumnMetadata columnMetadata;
+    private final PrimaryKey primaryKey;
+    // Either a Memtable reference or an SSTableId reference
+    private final Object sourceTable;
+
+    private final float indexScore;
+
+    public PrimaryKeyWithScore(ColumnMetadata columnMetadata, Memtable sourceTable, PrimaryKey primaryKey, float indexScore)
+    {
+        this.columnMetadata = columnMetadata;
+        this.sourceTable = sourceTable;
+        this.primaryKey = primaryKey;
+        this.indexScore = indexScore;
+    }
+
+    public PrimaryKeyWithScore(ColumnMetadata columnMetadata, SSTableId sourceTable, PrimaryKey primaryKey, float indexScore)
+    {
+        this.columnMetadata = columnMetadata;
+        this.sourceTable = sourceTable;
+        this.primaryKey = primaryKey;
+        this.indexScore = indexScore;
+    }
+
+    public PrimaryKey primaryKey()
+    {
+        return primaryKey;
+    }
+
+    public boolean isIndexDataValid(Row row, long nowInSecs)
+    {
+        // If the indexed column is part of the primary key, we don't need this type of validation because we would
+        // have fetched the row using the indexed primary key, so they have to match.
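+        // Otherwise, the source comparison below guards against stale scores: if the live cell was written by a
+        // different memtable or sstable than the one whose index produced this score, the indexed vector may not
+        // match the live value.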
+ if (columnMetadata.isPrimaryKeyColumn()) + return true; + + // If the row is static and the column is not static, or vice versa, the indexed value won't be present so we + // don't need to check if live data matches indexed data. + if (row.isStatic() != columnMetadata.isStatic()) + return true; + + Cell cell = row.getCell(columnMetadata); + if (!cell.isLive(nowInSecs)) + return false; + + assert cell instanceof CellWithSource : "Expected CellWithSource, got " + cell.getClass(); + return sourceTable.equals(((CellWithSource) cell).sourceTable()); + } + + @Override + public int compareTo(PrimaryKeyWithScore o) + { + // Descending order + return Float.compare(o.indexScore, indexScore); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdToPrimaryKeyWithScoreIterator.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdToPrimaryKeyWithScoreIterator.java new file mode 100644 index 000000000000..b09579e044a6 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdToPrimaryKeyWithScoreIterator.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import java.io.IOException; + +import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.schema.ColumnMetadata; +import org.apache.cassandra.utils.AbstractIterator; +import org.apache.cassandra.utils.CloseableIterator; + +/** + * An iterator over scored primary keys ordered by the score descending + * Not skippable. 
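+ * <p>
+ * For example, with a segmentRowIdOffset of 1000, a RowIdWithScore of (5, 0.9) is resolved through the
+ * per-sstable PrimaryKeyMap at sstable row id 1005 and emitted as a PrimaryKeyWithScore carrying score 0.9.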
+ */ +public class RowIdToPrimaryKeyWithScoreIterator extends AbstractIterator +{ + private final ColumnMetadata column; + private final SSTableId sstableId; + private final PrimaryKeyMap primaryKeyMap; + private final CloseableIterator scoredRowIdIterator; + private final long segmentRowIdOffset; + + public RowIdToPrimaryKeyWithScoreIterator(ColumnMetadata column, + PrimaryKeyMap.Factory primaryKeyMapFactory, + CloseableIterator scoredRowIdIterator, + long segmentRowIdOffset) throws IOException + { + this.column = column; + this.scoredRowIdIterator = scoredRowIdIterator; + this.primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); + this.sstableId = primaryKeyMap.getSSTableId(); + this.segmentRowIdOffset = segmentRowIdOffset; + } + + @Override + protected PrimaryKeyWithScore computeNext() + { + if (!scoredRowIdIterator.hasNext()) + return endOfData(); + RowIdWithScore rowIdWithScore = scoredRowIdIterator.next(); + return rowIdWithScore.toPrimaryKeyWithScore(column, sstableId, primaryKeyMap, segmentRowIdOffset); + } + + @Override + public void close() + { + FileUtils.closeQuietly(primaryKeyMap); + FileUtils.closeQuietly(scoredRowIdIterator); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdWithScore.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdWithScore.java new file mode 100644 index 000000000000..310080dd6b16 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/RowIdWithScore.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; +import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.schema.ColumnMetadata; + +/** + * Represents a row id with its computed score. 
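+ * Conversion to a {@link PrimaryKeyWithScore} is deferred to {@link #toPrimaryKeyWithScore} so that candidates
+ * that are never consumed need not be materialized into primary keys.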
+ */ +public class RowIdWithScore +{ + private final int segmentRowId; + private final float score; + + public RowIdWithScore(int segmentRowId, float score) + { + this.segmentRowId = segmentRowId; + this.score = score; + } + + public PrimaryKeyWithScore toPrimaryKeyWithScore(ColumnMetadata columnMetadata, + SSTableId sstableId, + PrimaryKeyMap primaryKeyMap, + long segmentRowIdOffset) + { + PrimaryKey pk = primaryKeyMap.primaryKeyFromRowId(segmentRowIdOffset + segmentRowId); + return new PrimaryKeyWithScore(columnMetadata, sstableId, pk, score); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/vector/SegmentRowIdOrdinalPairs.java b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/SegmentRowIdOrdinalPairs.java new file mode 100644 index 000000000000..6335bde14d7b --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/vector/SegmentRowIdOrdinalPairs.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import java.util.function.IntConsumer; + +import io.github.jbellis.jvector.graph.NeighborQueue; +import io.github.jbellis.jvector.graph.NeighborSimilarity; + +/** + * A specialized data structure that stores segment row id to ordinal pairs efficiently. Implemented as an array of int + * pairs that avoids boxing. + */ +public class SegmentRowIdOrdinalPairs +{ + private final int capacity; + private int size; + private final int[] array; + + /** + * Create a new SegmentRowIdOrdinalPairs with the given capacity. + * @param capacity the capacity + */ + public SegmentRowIdOrdinalPairs(int capacity) + { + assert capacity < Integer.MAX_VALUE / 2 : "capacity is too large " + capacity; + this.capacity = capacity; + this.size = 0; + this.array = new int[capacity * 2]; + } + + /** + * Add a pair to the array. + * @param segmentRowId the first value + * @param ordinal the second value + */ + public void add(int segmentRowId, int ordinal) + { + if (size == capacity) + throw new ArrayIndexOutOfBoundsException(size); + array[size * 2] = segmentRowId; + array[size * 2 + 1] = ordinal; + size++; + } + + /** + * Get the row id at the given index. + * @param index the index + * @return the row id + */ + public int getSegmentRowId(int index) + { + if ( index < 0 || index >= size) + throw new ArrayIndexOutOfBoundsException(index); + return array[index * 2]; + } + + /** + * Get the ordinal at the given index. + * @param index the index + * @return the ordinal + */ + public int getOrdinal(int index) + { + if ( index < 0 || index >= size) + throw new ArrayIndexOutOfBoundsException(index); + return array[index * 2 + 1]; + } + + /** + * The number of pairs in the array. 
+     * @return the number of pairs in the array
+     */
+    public int size()
+    {
+        return size;
+    }
+
+    /**
+     * Build a heap of the segment row id and scored ordinal pairs in the array.
+     * @param scoreFunction the score function to use to compute the next score based on the ordinal
+     * @return a {@link NeighborQueue} of (segment row id, score) pairs
+     */
+    public NeighborQueue mapToSegmentRowIdScoreHeap(NeighborSimilarity.ScoreFunction scoreFunction)
+    {
+        // TODO this could be improved using Floyd's algorithm in a later jvector version
+        NeighborQueue queue = new NeighborQueue(size(), true);
+        for (int i = 0; i < size; i++)
+            queue.add(array[i * 2], scoreFunction.similarityTo(array[i * 2 + 1])); // rowid, score
+        return queue;
+    }
+
+    /**
+     * Build a heap of the index and scored ordinal pairs in the array.
+     * @param scoreFunction the score function to use to compute the next score based on the ordinal
+     * @return a {@link NeighborQueue} of (index, score) pairs
+     */
+    public NeighborQueue mapToIndexScoreIterator(NeighborSimilarity.ScoreFunction scoreFunction)
+    {
+        // TODO this could be improved using Floyd's algorithm in a later jvector version
+        NeighborQueue queue = new NeighborQueue(size(), true);
+        for (int i = 0; i < size; i++)
+            queue.add(i, scoreFunction.similarityTo(array[i * 2 + 1])); // index, score
+        return queue;
+    }
+
+    /**
+     * Calls the consumer with the ordinal (the right value) of each pair in the array.
+     * @param consumer the consumer to call for each ordinal
+     */
+    public void forEachOrdinal(IntConsumer consumer)
+    {
+        for (int i = 0; i < size; i++)
+            consumer.accept(array[i * 2 + 1]);
+    }
+}
diff --git a/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeListIterator.java b/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeListIterator.java
deleted file mode 100644
index 4bf5d19bc3c3..000000000000
--- a/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeListIterator.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.cassandra.index.sai.iterators;
-
-import java.util.List;
-
-import com.google.common.collect.Iterators;
-import com.google.common.collect.PeekingIterator;
-
-import org.apache.cassandra.index.sai.utils.PrimaryKey;
-
-/**
- * A {@link KeyRangeIterator} that iterates over a list of {@link PrimaryKey}s without modifying the underlying list.
- */
-public class KeyRangeListIterator extends KeyRangeIterator
-{
-    private final PeekingIterator keyQueue;
-
-    /**
-     * Create a new {@link KeyRangeListIterator} that iterates over the provided list of keys.
- * - * @param minimumKey the minimum key for the provided list of keys - * @param maximumKey the maximum key for the provided list of keys - * @param keys the list of keys to iterate over - */ - public KeyRangeListIterator(PrimaryKey minimumKey, PrimaryKey maximumKey, List keys) - { - super(minimumKey, maximumKey, keys.size()); - this.keyQueue = Iterators.peekingIterator(keys.iterator()); - } - - @Override - protected void performSkipTo(PrimaryKey nextKey) - { - while (keyQueue.hasNext()) - { - if (keyQueue.peek().compareTo(nextKey, false) >= 0) - break; - keyQueue.next(); - } - } - - @Override - public void close() {} - - @Override - protected PrimaryKey computeNext() - { - return keyQueue.hasNext() ? keyQueue.next() : endOfData(); - } -} diff --git a/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeOrderingIterator.java b/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeOrderingIterator.java deleted file mode 100644 index 834fac884760..000000000000 --- a/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeOrderingIterator.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.index.sai.iterators; - -import java.util.ArrayList; -import java.util.List; -import java.util.function.Function; - -import javax.annotation.concurrent.NotThreadSafe; - -import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.io.util.FileUtils; - -/** - * An iterator that consumes a chunk of {@link PrimaryKey}s from the {@link KeyRangeIterator}, passes them to the - * {@link Function} to filter the chunk of {@link PrimaryKey}s and then pass the results to next consumer. - * The PKs are currently returned in {@link PrimaryKey} order, but that contract may change. - */ -@NotThreadSafe -public class KeyRangeOrderingIterator extends KeyRangeIterator -{ - private final KeyRangeIterator input; - private final int chunkSize; - private final Function, KeyRangeIterator> nextRangeFunction; - private final ArrayList nextKeys; - private KeyRangeIterator nextIterator; - - public KeyRangeOrderingIterator(KeyRangeIterator input, int chunkSize, Function, KeyRangeIterator> nextRangeFunction) - { - super(input, () -> {}); - this.input = input; - this.chunkSize = chunkSize; - this.nextRangeFunction = nextRangeFunction; - this.nextKeys = new ArrayList<>(chunkSize); - } - - @Override - public PrimaryKey computeNext() - { - if (nextIterator == null || !nextIterator.hasNext()) - { - do - { - if (!input.hasNext()) - return endOfData(); - nextKeys.clear(); - do - { - nextKeys.add(input.next()); - } - while (nextKeys.size() < chunkSize && input.hasNext()); - // Get the next iterator before closing this one to prevent releasing the resource. 
-                KeyRangeIterator previousIterator = nextIterator;
-                // If this results in an exception, previousIterator is closed in close() method.
-                nextIterator = nextRangeFunction.apply(nextKeys);
-                if (previousIterator != null)
-                    FileUtils.closeQuietly(previousIterator);
-                // nextIterator might not have any rows due to shadowed primary keys
-            }
-            while (!nextIterator.hasNext());
-        }
-        return nextIterator.next();
-    }
-
-    @Override
-    protected void performSkipTo(PrimaryKey nextToken)
-    {
-        input.skipTo(nextToken);
-        if (nextIterator != null)
-            nextIterator.skipTo(nextToken);
-    }
-
-    public void close()
-    {
-        FileUtils.closeQuietly(input);
-        FileUtils.closeQuietly(nextIterator);
-    }
-}
diff --git a/src/java/org/apache/cassandra/index/sai/iterators/PriorityQueueIterator.java b/src/java/org/apache/cassandra/index/sai/iterators/PriorityQueueIterator.java
new file mode 100644
index 000000000000..3da147bd59c4
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/iterators/PriorityQueueIterator.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.iterators;
+
+import java.util.PriorityQueue;
+
+import org.apache.cassandra.utils.AbstractIterator;
+
+/**
+ * An iterator that lazily consumes a priority queue.
+ * @param <T> the type of the elements in the priority queue
+ */
+public class PriorityQueueIterator<T> extends AbstractIterator<T>
+{
+    private final PriorityQueue<T> queue;
+
+    /**
+     * Build a PriorityQueueIterator.
+     * @param queue a priority queue to be lazily consumed by the iterator
+     */
+    public PriorityQueueIterator(PriorityQueue<T> queue)
+    {
+        this.queue = queue;
+    }
+
+    @Override
+    protected T computeNext()
+    {
+        return queue.isEmpty() ?
endOfData() : queue.poll(); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java index d592c94209cf..ca46228a3c3b 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java @@ -28,17 +28,19 @@ import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PartitionPosition; -import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.IndexIdentifier; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeys; +import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -47,12 +49,12 @@ public class MemtableIndex implements MemtableOrdering private final MemoryIndex memoryIndex; private final LongAdder writeCount = new LongAdder(); private final LongAdder estimatedMemoryUsed = new LongAdder(); - private final AbstractType type; + private final Memtable memtable; - public MemtableIndex(StorageAttachedIndex index) + public MemtableIndex(StorageAttachedIndex index, Memtable memtable) { - this.memoryIndex = index.termType().isVector() ? new VectorMemoryIndex(index) : new TrieMemoryIndex(index); - this.type = index.termType().indexType(); + this.memoryIndex = index.termType().isVector() ? 
new VectorMemoryIndex(index, memtable) : new TrieMemoryIndex(index); + this.memtable = memtable; } public long writeCount() @@ -70,6 +72,11 @@ public boolean isEmpty() return memoryIndex.isEmpty(); } + public Memtable getMemtable() + { + return memtable; + } + public ByteBuffer getMinTerm() { return memoryIndex.getMinTerm(); @@ -114,8 +121,14 @@ public SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDes } @Override - public KeyRangeIterator limitToTopResults(List primaryKeys, Expression expression, int limit) + public CloseableIterator orderBy(QueryContext queryContext, Expression orderer, AbstractBounds keyRange) + { + return memoryIndex.orderBy(queryContext, orderer, keyRange); + } + + @Override + public CloseableIterator orderResultsBy(QueryContext queryContext, List results, Expression orderer) { - return memoryIndex.limitToTopResults(primaryKeys, expression, limit); + return memoryIndex.orderResultsBy(queryContext, results, orderer); } } diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java index 99bff5a1e1bb..b5128198cb48 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java @@ -50,15 +50,25 @@ public MemtableIndexManager(StorageAttachedIndex index) this.liveMemtableIndexMap = new ConcurrentHashMap<>(); } - public long index(DecoratedKey key, Row row, Memtable mt) + public void maybeInitializeMemtableIndex(Memtable memtable) + { + if (index.termType().isVector()) + initializeMemtableIndex(memtable); + } + + private MemtableIndex initializeMemtableIndex(Memtable mt) { MemtableIndex current = liveMemtableIndexMap.get(mt); // We expect the relevant IndexMemtable to be present most of the time, so only make the // call to computeIfAbsent() if it's not. (see https://bugs.openjdk.java.net/browse/JDK-8161372) - MemtableIndex target = (current != null) - ? current - : liveMemtableIndexMap.computeIfAbsent(mt, memtable -> new MemtableIndex(index)); + return current != null ? current + : liveMemtableIndexMap.computeIfAbsent(mt, memtable -> new MemtableIndex(index, memtable)); + } + + public long index(DecoratedKey key, Row row, Memtable mt) + { + MemtableIndex target = initializeMemtableIndex(mt); long start = Clock.Global.nanoTime(); @@ -92,9 +102,9 @@ public long update(DecoratedKey key, Row oldRow, Row newRow, Memtable memtable) return index(key, newRow, memtable); } + // Updates should only be able to happen on memtables that were already created and that are still live. 
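+        // (The entry is created on the memtable's first write via index(), and for vector indexes it is
+        // created eagerly by maybeInitializeMemtableIndex(), so a missing entry here indicates a bug.)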
MemtableIndex target = liveMemtableIndexMap.get(memtable); - if (target == null) - return 0; + assert target != null : "Memtable for " + memtable.metadata().getTableName() + " not found"; ByteBuffer oldValue = index.termType().valueOf(key, oldRow, FBUtilities.nowInSeconds()); ByteBuffer newValue = index.termType().valueOf(key, newRow, FBUtilities.nowInSeconds()); diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableOrdering.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableOrdering.java index bd084a546b55..d437dde74d5a 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableOrdering.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableOrdering.java @@ -20,9 +20,13 @@ import java.util.List; -import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.utils.CloseableIterator; /** * Analogue of {@link org.apache.cassandra.index.sai.disk.v1.segment.SegmentOrdering}, but for memtables. @@ -30,13 +34,22 @@ public interface MemtableOrdering { /** - * Filter the given list of {@code PrimaryKey} results to the top `limit` results corresponding to the given expression, - * Returns an iterator over the results that is put back in token order. - *

-     * Assumes that the given list spans the same rows as the implementing index's segment.
+     * Order the rows in the index by the given orderer (expression).
+     *
+     * @param queryContext - the query context
+     * @param orderer - the expression to order by
+     * @param keyRange - the key range to search
+     * @return an iterator over the results in score order.
      */
-    default KeyRangeIterator limitToTopResults(List primaryKeys, Expression expression, int limit)
-    {
-        throw new UnsupportedOperationException();
-    }
+    CloseableIterator<PrimaryKeyWithScore> orderBy(QueryContext queryContext,
+                                                   Expression orderer,
+                                                   AbstractBounds<PartitionPosition> keyRange);
+
+    /**
+     * Order the given list of {@link PrimaryKey} results by the given orderer.
+     * Returns an iterator over the results in score order.
+     *
+     * Assumes that the given list spans the same rows as the implementing index's segment.
+     */
+    CloseableIterator<PrimaryKeyWithScore> orderResultsBy(QueryContext context, List<PrimaryKey> results, Expression orderer);
 }
diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java
index c531ba6696c3..d66eb47e72fc 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java
@@ -20,6 +20,7 @@
 import java.nio.ByteBuffer;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 import java.util.PriorityQueue;
 import java.util.SortedSet;
@@ -42,11 +43,13 @@
 import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
 import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
 import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
+import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore;
 import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
 import org.apache.cassandra.index.sai.plan.Expression;
 import org.apache.cassandra.index.sai.utils.IndexIdentifier;
 import org.apache.cassandra.index.sai.utils.PrimaryKey;
 import org.apache.cassandra.index.sai.utils.PrimaryKeys;
+import org.apache.cassandra.utils.CloseableIterator;
 import org.apache.cassandra.utils.Pair;
 import org.apache.cassandra.utils.bytecomparable.ByteComparable;
 
@@ -259,6 +262,18 @@ private KeyRangeIterator exactMatch(Expression expression, AbstractBounds<PartitionPosition> keyRange)
+    @Override
+    public CloseableIterator<PrimaryKeyWithScore> orderBy(QueryContext queryContext, Expression orderer, AbstractBounds<PartitionPosition> keyRange)
+    {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public CloseableIterator<PrimaryKeyWithScore> orderResultsBy(QueryContext context, List<PrimaryKey> results, Expression orderer)
+    {
+        throw new UnsupportedOperationException();
+    }
+
     private static class Collector
     {
         final PriorityQueue mergedKeys;
diff --git a/src/java/org/apache/cassandra/index/sai/memory/VectorMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/VectorMemoryIndex.java
index 976c81d205a0..ab538df2c29b 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/VectorMemoryIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/VectorMemoryIndex.java
@@ -20,7 +20,9 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.NavigableSet;
@@ -36,25 +38,31 @@
 import org.apache.cassandra.db.Clustering;
 import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.PartitionPosition;
+import org.apache.cassandra.db.memtable.Memtable;
 import org.apache.cassandra.dht.AbstractBounds;
 import
org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; -import org.apache.cassandra.index.sai.VectorQueryContext; import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; +import org.apache.cassandra.index.sai.iterators.PriorityQueueIterator; import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; import org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; -import org.apache.cassandra.index.sai.iterators.KeyRangeListIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.IndexIdentifier; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeys; import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.utils.AbstractIterator; +import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; +import io.github.jbellis.jvector.graph.SearchResult; import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import static java.lang.Math.log; import static java.lang.Math.max; @@ -64,6 +72,7 @@ public class VectorMemoryIndex extends MemoryIndex { private final OnHeapGraph graph; + private final Memtable memtable; private final LongAdder writeCount = new LongAdder(); private PrimaryKey minimumKey; @@ -71,10 +80,11 @@ public class VectorMemoryIndex extends MemoryIndex private final NavigableSet primaryKeys = new ConcurrentSkipListSet<>(); - public VectorMemoryIndex(StorageAttachedIndex index) + public VectorMemoryIndex(StorageAttachedIndex index, Memtable memtable) { super(index); - this.graph = new OnHeapGraph<>(index.termType().indexType(), index.indexWriterConfig()); + this.graph = new OnHeapGraph<>(index.termType().indexType(), index.indexWriterConfig(), memtable); + this.memtable = memtable; } @Override @@ -151,11 +161,15 @@ else if (primaryKey.compareTo(maximumKey) > 0) } @Override - public KeyRangeIterator search(QueryContext queryContext, Expression expr, AbstractBounds keyRange) + public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds keyRange) { - assert expr.getIndexOperator() == Expression.IndexOperator.ANN : "Only ANN is supported for vector search, received " + expr.getIndexOperator(); + throw new UnsupportedOperationException(); + } - VectorQueryContext vectorQueryContext = queryContext.vectorContext(); + @Override + public CloseableIterator orderBy(QueryContext queryContext, Expression expr, AbstractBounds keyRange) + { + assert expr.getIndexOperator() == Expression.IndexOperator.ANN : "Only ANN is supported for vector search, received " + expr.getIndexOperator(); ByteBuffer buffer = expr.lower().value.raw; float[] qv = index.termType().decomposeVector(buffer); @@ -174,73 +188,92 @@ public KeyRangeIterator search(QueryContext queryContext, Expression expr, Abstr PrimaryKey right = isMaxToken ? null : index.keyFactory().create(keyRange.right.getToken()); // upper bound Set resultKeys = isMaxToken ? 
primaryKeys.tailSet(left, leftInclusive) : primaryKeys.subSet(left, leftInclusive, right, rightInclusive); - if (!vectorQueryContext.getShadowedPrimaryKeys().isEmpty()) - resultKeys = resultKeys.stream().filter(pk -> !vectorQueryContext.containsShadowedPrimaryKey(pk)).collect(Collectors.toSet()); if (resultKeys.isEmpty()) - return KeyRangeIterator.empty(); + return CloseableIterator.empty(); - int bruteForceRows = maxBruteForceRows(vectorQueryContext.limit(), resultKeys.size(), graph.size()); + int bruteForceRows = maxBruteForceRows(queryContext.limit(), resultKeys.size(), graph.size()); Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - resultKeys.size(), bruteForceRows, graph.size(), vectorQueryContext.limit()); - if (resultKeys.size() < Math.max(vectorQueryContext.limit(), bruteForceRows)) - return new ReorderingRangeIterator(new PriorityQueue<>(resultKeys)); + resultKeys.size(), bruteForceRows, graph.size(), queryContext.limit()); + if (resultKeys.size() < Math.max(queryContext.limit(), bruteForceRows)) + return orderByBruteForce(qv, resultKeys); else - bits = new KeyRangeFilteringBits(keyRange, vectorQueryContext.bitsetForShadowedPrimaryKeys(graph)); + bits = new KeyRangeFilteringBits(keyRange, null); } else { - // partition/range deletion won't trigger index update, so we have to filter shadow primary keys in memtable index - bits = queryContext.vectorContext().bitsetForShadowedPrimaryKeys(graph); + // Accept all bits + bits = new Bits.MatchAllBits(Integer.MAX_VALUE); } - PriorityQueue keyQueue = graph.search(qv, queryContext.vectorContext().limit(), bits); - if (keyQueue.isEmpty()) - return KeyRangeIterator.empty(); - return new ReorderingRangeIterator(keyQueue); + CloseableIterator iterator = graph.search(qv, queryContext.limit(), bits); + return new NodeScoreToScoredPrimaryKeyIterator(iterator); } @Override - public KeyRangeIterator limitToTopResults(List primaryKeys, Expression expression, int limit) + public CloseableIterator orderResultsBy(QueryContext queryContext, List results, Expression orderer) { if (minimumKey == null) // This case implies maximumKey is empty too. 
- return KeyRangeIterator.empty(); + return CloseableIterator.empty(); - List results = primaryKeys.stream() - .dropWhile(k -> k.compareTo(minimumKey) < 0) - .takeWhile(k -> k.compareTo(maximumKey) <= 0) - .collect(Collectors.toList()); + int limit = queryContext.limit(); - int maxBruteForceRows = maxBruteForceRows(limit, results.size(), graph.size()); + List resultsInRange = results.stream() + .dropWhile(k -> k.compareTo(minimumKey) < 0) + .takeWhile(k -> k.compareTo(maximumKey) <= 0) + .collect(Collectors.toList()); + + int maxBruteForceRows = maxBruteForceRows(limit, resultsInRange.size(), graph.size()); Tracing.trace("SAI materialized {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - results.size(), maxBruteForceRows, graph.size(), limit); - if (results.size() <= maxBruteForceRows) - { - if (results.isEmpty()) - return KeyRangeIterator.empty(); - return new KeyRangeListIterator(minimumKey, maximumKey, results); - } + resultsInRange.size(), maxBruteForceRows, graph.size(), limit); - ByteBuffer buffer = expression.lower().value.raw; + if (resultsInRange.isEmpty()) + return CloseableIterator.empty(); + + ByteBuffer buffer = orderer.lower().value.raw; float[] qv = index.termType().decomposeVector(buffer); - KeyFilteringBits bits = new KeyFilteringBits(results); - PriorityQueue keyQueue = graph.search(qv, limit, bits); - if (keyQueue.isEmpty()) - return KeyRangeIterator.empty(); - return new ReorderingRangeIterator(keyQueue); + + if (resultsInRange.size() <= maxBruteForceRows) + return orderByBruteForce(qv, resultsInRange); + + // Search the graph for the topK vectors near the query + KeyFilteringBits bits = new KeyFilteringBits(resultsInRange); + CloseableIterator nodeScores = graph.search(qv, limit, bits); + return new NodeScoreToScoredPrimaryKeyIterator(nodeScores); } private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize) { int expectedNodesVisited = expectedNodesVisited(limit, nPermittedOrdinals, graphSize); - int expectedComparisons = index.indexWriterConfig().getMaximumNodeConnections() * expectedNodesVisited; - // in-memory comparisons are cheaper than pulling a row off disk and then comparing - // VSTODO this is dramatically oversimplified - // larger dimension should increase this, because comparisons are more expensive - // lower chunk cache hit ratio should decrease this, because loading rows is more expensive - double memoryToDiskFactor = 0.25; - return (int) max(limit, memoryToDiskFactor * expectedComparisons); + // ANN index will do a bunch of extra work besides the full comparisons (performing PQ similarity for each edge); + // VSTODO I'm not sure which one is more expensive (and it depends on things like sstable chunk cache hit ratio) + // so I'm leaving it as a 1:1 ratio for now. + return max(limit, expectedNodesVisited); + } + + private CloseableIterator orderByBruteForce(float[] queryVector, Collection keys) + { + VectorSimilarityFunction similarityFunction = index.indexWriterConfig().getSimilarityFunction(); + List scoredKeys = new ArrayList<>(keys.size()); + for (PrimaryKey key : keys) + { + PrimaryKeyWithScore scoredKey = scoreKey(similarityFunction, queryVector, key); + if (scoredKey != null) + scoredKeys.add(scoredKey); + } + // Because we merge iterators from all sstables and memtables, we do not need a complete sort of these + // elements, so a priority queue provides good performance. 
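+        // (Heapifying the candidate list is O(n) and each poll is O(log n), so a consumer that stops after
+        // the global top k rows does less work than the O(n log n) a full sort would cost.)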
+ return new PriorityQueueIterator<>(new PriorityQueue<>(scoredKeys)); + } + + private PrimaryKeyWithScore scoreKey(VectorSimilarityFunction similarityFunction, float[] queryVector, PrimaryKey key) + { + float[] vector = graph.vectorForKey(key); + if (vector == null) + return null; + float score = similarityFunction.compare(queryVector, vector); + return new PrimaryKeyWithScore(index.termType().columnMetadata(), memtable, key, score); } /** @@ -323,56 +356,67 @@ public int length() } } - private class ReorderingRangeIterator extends KeyRangeIterator + private class KeyFilteringBits implements Bits { - private final PriorityQueue keyQueue; + private final List results; - ReorderingRangeIterator(PriorityQueue keyQueue) + public KeyFilteringBits(List results) { - super(minimumKey, maximumKey, keyQueue.size()); - this.keyQueue = keyQueue; + this.results = results; } @Override - // VSTODO maybe we can abuse "current" to avoid having to pop and re-add the last skipped key - protected void performSkipTo(PrimaryKey nextKey) + public boolean get(int i) { - while (!keyQueue.isEmpty() && keyQueue.peek().compareTo(nextKey) < 0) - keyQueue.poll(); + Collection pk = graph.keysFromOrdinal(i); + return results.stream().anyMatch(pk::contains); } @Override - public void close() {} - - @Override - protected PrimaryKey computeNext() + public int length() { - if (keyQueue.isEmpty()) - return endOfData(); - return keyQueue.poll(); + return results.size(); } } - private class KeyFilteringBits implements Bits + /** + * An iterator over {@link PrimaryKeyWithScore} sorted by score descending. The iterator converts ordinals (node ids) + * to {@link PrimaryKey}s and pairs them with the score given by the index. + */ + private class NodeScoreToScoredPrimaryKeyIterator extends AbstractIterator { - private final List results; + private final CloseableIterator nodeScores; + private Iterator primaryKeysForNode = Collections.emptyIterator(); - public KeyFilteringBits(List results) + NodeScoreToScoredPrimaryKeyIterator(CloseableIterator nodeScores) { - this.results = results; + this.nodeScores = nodeScores; } @Override - public boolean get(int i) + protected PrimaryKeyWithScore computeNext() { - Collection pk = graph.keysFromOrdinal(i); - return results.stream().anyMatch(pk::contains); + if (primaryKeysForNode.hasNext()) + return primaryKeysForNode.next(); + + while (nodeScores.hasNext()) + { + SearchResult.NodeScore nodeScore = nodeScores.next(); + primaryKeysForNode = graph.keysFromOrdinal(nodeScore.node) + .stream() + .map(pk -> new PrimaryKeyWithScore(index.termType().columnMetadata(), memtable, pk, nodeScore.score)) + .iterator(); + if (primaryKeysForNode.hasNext()) + return primaryKeysForNode.next(); + } + + return endOfData(); } @Override - public int length() + public void close() { - return results.size(); + FileUtils.closeQuietly(nodeScores); } } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Operation.java b/src/java/org/apache/cassandra/index/sai/plan/Operation.java index eb45722af54b..580bf8b6e9bc 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Operation.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Operation.java @@ -26,7 +26,6 @@ import java.util.List; import java.util.Set; import java.util.function.BiFunction; -import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ArrayListMultimap; @@ -42,9 +41,11 @@ import org.apache.cassandra.index.sai.QueryContext; import 
org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.utils.IndexTermType; import org.apache.cassandra.schema.ColumnMetadata; +import org.apache.cassandra.utils.CloseableIterator; public class Operation { @@ -319,16 +320,23 @@ private static int getPriority(Operator op) */ static KeyRangeIterator buildIterator(QueryController controller) { - List orderings = controller.indexFilter().getExpressions() - .stream().filter(e -> e.operator() == Operator.ANN).collect(Collectors.toList()); - assert orderings.size() <= 1; - if (controller.indexFilter().getExpressions().size() == 1 && orderings.size() == 1) + return Node.buildTree(controller.indexFilter()).analyzeTree(controller).rangeIterator(controller); + } + + /** + * Converts expressions into filter tree for query. + * + * @return a KeyRangeIterator over the index query results + */ + static CloseableIterator buildIteratorForOrder(QueryController controller, QueryViewBuilder.QueryExpressionView view) + { + if (controller.indexFilter().getExpressions().size() == 1) // If we only have one expression, we just use the ANN index to order and limit. - return controller.getTopKRows(orderings.get(0)); - KeyRangeIterator iterator = Node.buildTree(controller.indexFilter()).analyzeTree(controller).rangeIterator(controller); - if (orderings.isEmpty()) - return iterator; - return controller.getTopKRows(iterator, orderings.get(0)); + return controller.getTopKRows(view); + + // Otherwise, we need to search first, then order. + KeyRangeIterator iterator = buildIterator(controller); + return controller.getTopKRows(iterator, view); } /** diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 49e2aeab6d88..b0822e77bb2d 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -18,21 +18,20 @@ package org.apache.cassandra.index.sai.plan; -import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.NavigableSet; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; import com.google.common.collect.Lists; -import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.DataRange; @@ -48,30 +47,54 @@ import org.apache.cassandra.db.filter.DataLimits; import org.apache.cassandra.db.filter.RowFilter; import org.apache.cassandra.db.guardrails.Guardrails; +import org.apache.cassandra.db.rows.BaseRowIterator; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.UnfilteredRowIterator; +import org.apache.cassandra.db.transform.Transformation; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; -import org.apache.cassandra.index.sai.VectorQueryContext; import org.apache.cassandra.index.sai.disk.IndexSearchResultIterator; import 
org.apache.cassandra.index.sai.disk.SSTableIndex;
-import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator;
+import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore;
 import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator;
 import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
-import org.apache.cassandra.index.sai.iterators.KeyRangeOrderingIterator;
 import org.apache.cassandra.index.sai.iterators.KeyRangeUnionIterator;
+import org.apache.cassandra.index.sai.memory.MemtableIndex;
+import org.apache.cassandra.index.sai.utils.MergePrimaryKeyWithScoreIterator;
 import org.apache.cassandra.index.sai.utils.PrimaryKey;
+import org.apache.cassandra.index.sai.utils.RowWithSource;
+import org.apache.cassandra.io.util.FileUtils;
 import org.apache.cassandra.net.ParamType;
 import org.apache.cassandra.schema.TableMetadata;
 import org.apache.cassandra.tracing.Tracing;
+import org.apache.cassandra.utils.CloseableIterator;
 import org.apache.cassandra.utils.InsertionOrderedNavigableSet;
 import org.apache.cassandra.utils.Throwables;
 
-import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE;
-
 public class QueryController
 {
+    // Transforms a row to include its source table, which is then used for ANN query validation.
+    private static final Function<Object, Transformation<BaseRowIterator<?>>> SOURCE_TABLE_ROW_TRANSFORMER = (Object sourceTable) -> new Transformation<>()
+    {
+        @Override
+        protected Row applyToStatic(Row row)
+        {
+            return new RowWithSource(row, sourceTable);
+        }
+
+        @Override
+        protected Row applyToRow(Row row)
+        {
+            return new RowWithSource(row, sourceTable);
+        }
+    };
+
+    /**
+     * The maximum number of primary keys we will materialize when performing hybrid vector search. If this limit is
+     * exceeded, we switch to an order-by-then-filter execution path.
+     */
+    public static int MAX_MATERIALIZED_KEYS = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_MATERIALIZE_KEYS.getInt();
+
     final QueryContext queryContext;
 
     private final ColumnFamilyStore cfs;
@@ -82,7 +105,6 @@ public class QueryController
     private final PrimaryKey.Factory keyFactory;
     private final PrimaryKey firstPrimaryKey;
     private final PrimaryKey lastPrimaryKey;
-    private final int orderChunkSize;
 
     private final NavigableSet<Clustering<?>> nextClusterings;
 
@@ -102,7 +124,6 @@ public QueryController(ColumnFamilyStore cfs,
         this.keyFactory = new PrimaryKey.Factory(cfs.getPartitioner(), cfs.getComparator());
         this.firstPrimaryKey = keyFactory.create(mergeRange.left.getToken());
         this.lastPrimaryKey = keyFactory.create(mergeRange.right.getToken());
-        this.orderChunkSize = SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE.getInt();
         this.nextClusterings = new InsertionOrderedNavigableSet<>(cfs.metadata().comparator);
     }
 
@@ -144,6 +165,11 @@ public List<DataRange> dataRanges()
         return ranges;
     }
 
+    public AbstractBounds<PartitionPosition> mergeRange()
+    {
+        return mergeRange;
+    }
+
     @Nullable
     public StorageAttachedIndex indexFor(RowFilter.Expression expression)
     {
@@ -188,6 +214,31 @@ public void run()
         };
     }
 
+    /**
+     * Get an iterator over the row(s) for this primary key. Restrict the search to the specified view. Apply the
+     * {@link #SOURCE_TABLE_ROW_TRANSFORMER} so that resulting cells carry their source memtable/sstable. Expect one
+     * row for a fully qualified primary key, or all rows within a partition for a static primary key.
+     *
+     * @param key primary key to fetch from storage
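+     * @param view the memtable/sstable view fragment to restrict the read to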
+ * @param executionController the executionController to use when querying storage + * @return an iterator of rows matching the query + */ + public UnfilteredRowIterator queryStorage(PrimaryKey key, ColumnFamilyStore.ViewFragment view, ReadExecutionController executionController) + { + if (key == null) + throw new IllegalArgumentException("non-null key required"); + + SinglePartitionReadCommand partition = SinglePartitionReadCommand.create(cfs.metadata(), + command.nowInSec(), + command.columnFilter(), + RowFilter.none(), + DataLimits.NONE, + key.partitionKey(), + makeFilter(List.of(key))); + + return partition.queryMemtableAndDisk(cfs, view, SOURCE_TABLE_ROW_TRANSFORMER, executionController); + } + /** * Build a {@link KeyRangeIterator.Builder} from the given list of {@link Expression}s. *

    @@ -216,10 +267,9 @@ public KeyRangeIterator.Builder getIndexQueryResults(Collection expr expressions = expressions.stream().filter(e -> e.getIndexOperator() != Expression.IndexOperator.ANN).collect(Collectors.toList()); QueryViewBuilder.QueryView queryView = new QueryViewBuilder(expressions, mergeRange).build(); - Runnable onClose = getIndexReleaser(queryView.referencedIndexes); KeyRangeIterator.Builder builder = command.rowFilter().isStrict() - ? KeyRangeIntersectionIterator.builder(expressions.size(), onClose) - : KeyRangeUnionIterator.builder(expressions.size(), onClose); + ? KeyRangeIntersectionIterator.builder(expressions.size(), queryView::close) + : KeyRangeUnionIterator.builder(expressions.size(), queryView::close); try { @@ -286,7 +336,7 @@ public KeyRangeIterator.Builder getIndexQueryResults(Collection expr return builder; } - private void maybeTriggerGuardrails(QueryViewBuilder.QueryView queryView) + void maybeTriggerGuardrails(QueryViewBuilder.QueryView queryView) { int referencedIndexes = 0; @@ -330,117 +380,96 @@ public boolean doesNotSelect(PrimaryKey key) } // This is an ANN only query - public KeyRangeIterator getTopKRows(RowFilter.Expression expression) + public CloseableIterator getTopKRows(QueryViewBuilder.QueryExpressionView queryExpressionView) { - assert expression.operator() == Operator.ANN; - StorageAttachedIndex index = indexFor(expression); - assert index != null; - Expression planExpression = Expression.create(index).add(Operator.ANN, expression.getIndexValue().duplicate()); - - QueryViewBuilder.QueryView queryView = new QueryViewBuilder(Collections.singleton(planExpression), mergeRange).build(); - Runnable onClose = getIndexReleaser(queryView.referencedIndexes); - + assert queryExpressionView.expression.operator == Expression.IndexOperator.ANN; + List> intermediateResults = new ArrayList<>(); try { - List memtableResults = queryView.view - .stream() - .map(v -> v.memtableIndexes) - .flatMap(Collection::stream) - .map(idx -> idx.search(queryContext, planExpression, mergeRange)) - .collect(Collectors.toList()); - - List sstableIntersections = queryView.view - .stream() - .map(this::createRowIdIterator) - .collect(Collectors.toList()); - - return IndexSearchResultIterator.build(sstableIntersections, memtableResults, queryView.referencedIndexes, queryContext, onClose); + for (MemtableIndex memtableIndex : queryExpressionView.memtableIndexes) + intermediateResults.add(memtableIndex.orderBy(queryContext, queryExpressionView.expression, mergeRange)); + for (SSTableIndex sstableIndex : queryExpressionView.sstableIndexes) + intermediateResults.addAll(sstableIndex.orderBy(queryExpressionView.expression, mergeRange, queryContext)); + return intermediateResults.isEmpty() ? CloseableIterator.empty() + : new MergePrimaryKeyWithScoreIterator(intermediateResults); } catch (Throwable t) { // all sstable indexes in view have been referenced, need to clean up when exception is thrown - onClose.run(); - throw t; + queryExpressionView.sstableIndexes.forEach(SSTableIndex::releaseQuietly); + intermediateResults.forEach(FileUtils::closeQuietly); + throw Throwables.cleaned(t); } } // This is a hybrid query. We apply all other predicates before ordering and limiting. 
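+    // The other predicates' keys are materialized first (bounded by MAX_MATERIALIZED_KEYS); if that bound is
+    // exceeded, materializeKeysAndCloseSource() returns null and we fall back to the ANN-only path, which
+    // orders by score first and leaves the WHERE clause to be applied as a post-filter.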
- public KeyRangeIterator getTopKRows(KeyRangeIterator source, RowFilter.Expression expression) + public CloseableIterator getTopKRows(KeyRangeIterator source, QueryViewBuilder.QueryExpressionView queryExpressionView) { - return new KeyRangeOrderingIterator(source, orderChunkSize, list -> this.getTopKRows(list, expression)); + List primaryKeys = materializeKeysAndCloseSource(source); + if (primaryKeys == null) + return getTopKRows(queryExpressionView); + if (primaryKeys.isEmpty()) + return CloseableIterator.empty(); + return getTopKRows(primaryKeys, queryExpressionView); } - private KeyRangeIterator getTopKRows(List rawSourceKeys, RowFilter.Expression expression) + private CloseableIterator getTopKRows(List sourceKeys, QueryViewBuilder.QueryExpressionView queryExpressionView) { - VectorQueryContext vectorQueryContext = queryContext.vectorContext(); - // Filter out PKs now. Each PK is passed to every segment of the ANN index, so filtering shadowed keys - // eagerly can save some work when going from PK to row id for on disk segments. - // Since the result is shared with multiple streams, we use an unmodifiable list. - List sourceKeys = rawSourceKeys.stream().filter(vectorQueryContext::shouldInclude).collect(Collectors.toList()); - StorageAttachedIndex index = indexFor(expression); - assert index != null : "Cannot do ANN ordering on an unindexed column"; - Expression planExpression = Expression.create(index); - planExpression.add(Operator.ANN, expression.getIndexValue().duplicate()); - - QueryViewBuilder.QueryView queryView = new QueryViewBuilder(Collections.singleton(planExpression), mergeRange).build(); - Runnable onClose = getIndexReleaser(queryView.referencedIndexes); - + List> intermediateResults = new ArrayList<>(); try { - List memtableResults = queryView.view - .stream() - .map(v -> v.memtableIndexes) - .flatMap(Collection::stream) - .map(idx -> idx.limitToTopResults(sourceKeys, planExpression, vectorQueryContext.limit())) - .collect(Collectors.toList()); - - List sstableIntersections = queryView.view - .stream() - .flatMap(pair -> pair.sstableIndexes.stream()) - .map(idx -> { - try - { - return idx.limitToTopKResults(queryContext, sourceKeys, planExpression); - } - catch (IOException e) - { - throw new UncheckedIOException(e); - } - }) - .collect(Collectors.toList()); - - return IndexSearchResultIterator.build(sstableIntersections, memtableResults, queryView.referencedIndexes, queryContext, onClose); + for (MemtableIndex memtableIndex : queryExpressionView.memtableIndexes) + intermediateResults.add(memtableIndex.orderResultsBy(queryContext, sourceKeys, queryExpressionView.expression)); + for (SSTableIndex sstableIndex : queryExpressionView.sstableIndexes) + intermediateResults.addAll(sstableIndex.orderResultsBy(queryContext, sourceKeys, queryExpressionView.expression)); + return intermediateResults.isEmpty() ? CloseableIterator.empty() + : new MergePrimaryKeyWithScoreIterator(intermediateResults); } catch (Throwable t) { // all sstable indexes in view have been referenced, need to clean up when exception is thrown - onClose.run(); - throw t; + queryExpressionView.sstableIndexes.forEach(SSTableIndex::releaseQuietly); + intermediateResults.forEach(FileUtils::closeQuietly); + throw Throwables.cleaned(t); } } /** - * Create row id iterator from different indexes' on-disk searcher of the same sstable + * Materialize the keys from the given source iterator. If there is a meaningful {@link #mergeRange}, the keys + * are filtered to only include those within the range. 
Note: closes the source iterator. + * @param source The source iterator to fully consume by materializing its keys + * @return The list of materialized keys within the {@link #mergeRange}, or return null if source exceeded the + * materialized keys limit. */ - private KeyRangeIterator createRowIdIterator(QueryViewBuilder.QueryExpressionView indexExpression) + private List materializeKeysAndCloseSource(KeyRangeIterator source) { - List subIterators = indexExpression.sstableIndexes - .stream() - .map(index -> - { - try - { - List iterators = index.search(indexExpression.expression, mergeRange, queryContext); - // concat the result from multiple segments for the same index - return KeyRangeConcatIterator.builder(iterators.size()).add(iterators).build(); - } - catch (Throwable ex) - { - throw Throwables.cleaned(ex); - } - }).collect(Collectors.toList()); - - return KeyRangeUnionIterator.build(subIterators); + try (source) + { + // Skip to the first key (which is really just a token) in the range if it is not the minimum token + if (!mergeRange.left.isMinimum()) + source.skipTo(firstPrimaryKey); + + if (!source.hasNext()) + return List.of(); + + PrimaryKey maxToken = keyFactory.create(mergeRange.right.getToken()); + boolean hasLimitingMaxToken = !maxToken.token().isMinimum() && maxToken.compareTo(source.getMaximum()) < 0; + List primaryKeys = new ArrayList<>(); + int count = 0; + while (source.hasNext()) + { + PrimaryKey next = source.next(); + if (hasLimitingMaxToken && next.compareTo(maxToken) > 0) + break; + primaryKeys.add(next); + if (MAX_MATERIALIZED_KEYS < ++count) + { + Tracing.trace("WHERE clause generated more than {} rows. Switching to ORDER BY then post filter.", MAX_MATERIALIZED_KEYS); + return null; + } + } + return primaryKeys; + } } // Note: This method assumes that the selects method has already been called for the diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryMaterializesTooManyPrimaryKeysException.java b/src/java/org/apache/cassandra/index/sai/plan/QueryMaterializesTooManyPrimaryKeysException.java new file mode 100644 index 000000000000..cdeea35d4c13 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryMaterializesTooManyPrimaryKeysException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.plan; + +import org.apache.cassandra.db.RejectException; + +public class QueryMaterializesTooManyPrimaryKeysException extends RejectException +{ + public QueryMaterializesTooManyPrimaryKeysException(String msg) + { + super(msg); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryViewBuilder.java b/src/java/org/apache/cassandra/index/sai/plan/QueryViewBuilder.java index cecd45a928af..0d2aeb3838fb 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryViewBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryViewBuilder.java @@ -25,7 +25,9 @@ import java.util.Set; import java.util.stream.Collectors; +import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.disk.SSTableIndex; import org.apache.cassandra.index.sai.memory.MemtableIndex; @@ -64,9 +66,18 @@ public QueryExpressionView(Expression expression, Collection memt this.memtableIndexes = memtableIndexes; this.sstableIndexes = sstableIndexes; } + + public ColumnFamilyStore.ViewFragment computeViewFragment() + { + // Because the SSTableIndex holds a reference to the SSTableReader, we know the sstable is still accessible + // so it is safe to build a view fragment. + List memtables = memtableIndexes.stream().map(MemtableIndex::getMemtable).collect(Collectors.toList()); + List sstableReaders = sstableIndexes.stream().map(SSTableIndex::getSSTable).collect(Collectors.toList()); + return new ColumnFamilyStore.ViewFragment(sstableReaders, memtables); + } } - public static class QueryView + public static class QueryView implements AutoCloseable { public final Collection view; public final Set referencedIndexes; @@ -76,6 +87,12 @@ public QueryView(Collection view, Set referen this.view = view; this.referencedIndexes = referencedIndexes; } + + @Override + public void close() + { + referencedIndexes.forEach(SSTableIndex::releaseQuietly); + } } protected QueryView build() diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java index 77065c09ed56..0e208f8b7b16 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java @@ -151,7 +151,7 @@ public Function postProcessor(ReadCommand return partitions -> partitions; // in case of top-k query, filter out rows that are not actually global top-K - return partitions -> (PartitionIterator) new VectorTopKProcessor(command).filter(partitions); + return partitions -> (PartitionIterator) new VectorTopKProcessor(command).consumeSortByScoreAndTakeTopK(partitions); } /** diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index cf5beb129bb1..8d84771d0c8b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -19,19 +19,25 @@ package org.apache.cassandra.index.sai.plan; import java.nio.ByteBuffer; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import 
java.util.List;
+import java.util.Map;
 import java.util.NoSuchElementException;
-import java.util.Set;
+import java.util.PriorityQueue;
+import java.util.Queue;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import org.apache.cassandra.cql3.Operator;
 import org.apache.cassandra.db.Clustering;
 import org.apache.cassandra.db.ClusteringBound;
 import org.apache.cassandra.db.ClusteringComparator;
@@ -59,13 +65,18 @@
 import org.apache.cassandra.exceptions.RequestTimeoutException;
 import org.apache.cassandra.index.Index;
 import org.apache.cassandra.index.sai.QueryContext;
+import org.apache.cassandra.index.sai.StorageAttachedIndex;
+import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore;
 import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
 import org.apache.cassandra.index.sai.metrics.TableQueryMetrics;
 import org.apache.cassandra.index.sai.utils.PrimaryKey;
+import org.apache.cassandra.index.sai.utils.RangeUtil;
 import org.apache.cassandra.io.util.FileUtils;
 import org.apache.cassandra.schema.TableMetadata;
 import org.apache.cassandra.utils.AbstractIterator;
 import org.apache.cassandra.utils.Clock;
+import org.apache.cassandra.utils.CloseableIterator;
+import org.apache.cassandra.utils.FBUtilities;
 
 import io.netty.util.concurrent.FastThreadLocal;
 
@@ -122,24 +133,42 @@ public PartitionIterator filterReplicaFilteringProtection(PartitionIterator full
     public UnfilteredPartitionIterator search(ReadExecutionController executionController) throws RequestTimeoutException
     {
         if (!command.isTopK())
-            return new ResultRetriever(executionController, false);
+        {
+            return new ResultRetriever(executionController);
+        }
         else
         {
-            Supplier resultSupplier = () -> new ResultRetriever(executionController, true);
-
-            // VSTODO performance: if there is shadowed primary keys, we have to at least query twice.
-            // First time to find out there are shadow keys, second time to find out there are no more shadow keys.
-            while (true)
+            // Need a consistent view of the memtables/sstables and their associated index, so we get the view now
+            // and propagate it as needed.
+            try (QueryViewBuilder.QueryView queryView = buildAnnQueryView())
             {
-                long lastShadowedKeysCount = queryContext.vectorContext().getShadowedPrimaryKeys().size();
-                ResultRetriever result = resultSupplier.get();
-                UnfilteredPartitionIterator topK = (UnfilteredPartitionIterator) new VectorTopKProcessor(command).filter(result);
+                queryController.maybeTriggerGuardrails(queryView);
+                ScoreOrderedResultRetriever result = new ScoreOrderedResultRetriever(queryController, executionController, queryContext, queryView, command.limits().count());
+                // takeTopKThenSortByPrimaryKey eagerly consumes up to k rows from the result because search must
+                // produce an iterator in PrimaryKey order.
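+                // (An UnfilteredPartitionIterator must hand partitions back to the read path in token order,
+                // so the score-ordered rows are re-sorted by primary key before being returned.)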
+ return (UnfilteredPartitionIterator) new VectorTopKProcessor(command).takeTopKThenSortByPrimaryKey(result); + } + } + } - long currentShadowedKeysCount = queryContext.vectorContext().getShadowedPrimaryKeys().size(); - if (lastShadowedKeysCount == currentShadowedKeysCount) - return topK; + private QueryViewBuilder.QueryView buildAnnQueryView() + { + RowFilter.Expression annExpression = null; + for (RowFilter.Expression expression : queryController.indexFilter().getExpressions()) + { + if (expression.operator() == Operator.ANN) + { + if (annExpression != null) + throw new IllegalStateException("Multiple ANN expressions in a single query are not supported"); + annExpression = expression; } } + if (annExpression == null) + throw new IllegalStateException("No ANN expression found in query"); + + StorageAttachedIndex index = queryController.indexFor(annExpression); + Expression planExpression = Expression.create(index).add(Operator.ANN, annExpression.getIndexValue().duplicate()); + return new QueryViewBuilder(Collections.singleton(planExpression), queryController.mergeRange()).build(); } private class ResultRetriever extends AbstractIterator implements UnfilteredPartitionIterator @@ -154,12 +183,11 @@ private class ResultRetriever extends AbstractIterator im private final FilterTree filterTree; private final ReadExecutionController executionController; private final PrimaryKey.Factory keyFactory; - private final boolean topK; private final int partitionRowBatchSize; private PrimaryKey lastKey; - private ResultRetriever(ReadExecutionController executionController, boolean topK) + private ResultRetriever(ReadExecutionController executionController) { this.keyRanges = queryController.dataRanges().iterator(); this.firstDataRange = keyRanges.next(); @@ -170,7 +198,6 @@ private ResultRetriever(ReadExecutionController executionController, boolean top this.keyFactory = queryController.primaryKeyFactory(); this.firstPrimaryKey = queryController.firstPrimaryKeyInRange(); this.lastPrimaryKey = queryController.lastPrimaryKeyInRange(); - this.topK = topK; // Ensure we don't fetch larger batches than the provided LIMIT to avoid fetching keys we won't use: this.partitionRowBatchSize = Math.min(PARTITION_ROW_BATCH_SIZE, command.limits().count()); @@ -460,119 +487,345 @@ private UnfilteredRowIterator queryStorageAndFilter(List keys) queryContext.partitionsRead++; queryContext.checkpoint(); - UnfilteredRowIterator filtered = filterPartition(keys, partition, filterTree); + List filtered = filterPartition(partition, filterTree, queryContext); - // Note that we record the duration of the read after post-filtering, which actually + // Note that we record the duration of the read after post-filtering, which actually // materializes the rows from disk. tableQueryMetrics.postFilteringReadLatency.update(Clock.Global.nanoTime() - startTimeNanos, TimeUnit.NANOSECONDS); - return filtered; + return filtered != null + ? new SinglePartitionIterator(partition, partition.staticRow(), filtered.iterator()) + : null; } } - private UnfilteredRowIterator filterPartition(List keys, UnfilteredRowIterator partition, FilterTree tree) + @Override + public TableMetadata metadata() { - Row staticRow = partition.staticRow(); - DecoratedKey partitionKey = partition.partitionKey(); - List matches = new ArrayList<>(); - boolean hasMatch = false; - Set keysToShadow = topK ? 
new HashSet<>(keys) : Collections.emptySet(); - - while (partition.hasNext()) - { - Unfiltered unfiltered = partition.next(); - - if (unfiltered.isRow()) - { - queryContext.rowsFiltered++; + return queryController.metadata(); + } - if (tree.isSatisfiedBy(partitionKey, (Row) unfiltered, staticRow)) - { - matches.add(unfiltered); - hasMatch = true; + @Override + public void close() + { + FileUtils.closeQuietly(resultKeyIterator); + if (tableQueryMetrics != null) tableQueryMetrics.record(queryContext); + } + } - if (topK) - { - PrimaryKey shadowed = keyFactory.hasClusteringColumns() - ? keyFactory.create(partitionKey, ((Row) unfiltered).clustering()) - : keyFactory.create(partitionKey); - keysToShadow.remove(shadowed); - } - } - } - } + private static List filterPartition(UnfilteredRowIterator partition, FilterTree tree, QueryContext context) + { + Row staticRow = partition.staticRow(); + DecoratedKey partitionKey = partition.partitionKey(); + List matches = new ArrayList<>(); + boolean hasMatch = false; - // If any non-static rows match the filter, there should be no need to shadow the static primary key: - if (topK && hasMatch && keyFactory.hasClusteringColumns()) - keysToShadow.remove(keyFactory.create(partitionKey, Clustering.STATIC_CLUSTERING)); + while (partition.hasNext()) + { + Unfiltered unfiltered = partition.next(); - // We may not have any non-static row data to filter... - if (!hasMatch) + if (unfiltered.isRow()) { - queryContext.rowsFiltered++; + context.rowsFiltered++; - if (tree.isSatisfiedBy(partitionKey, staticRow, staticRow)) + if (tree.isSatisfiedBy(partitionKey, (Row) unfiltered, staticRow)) { + matches.add((Row) unfiltered); hasMatch = true; - - if (topK) - keysToShadow.clear(); } } + } - if (topK && !keysToShadow.isEmpty()) + // We may not have any non-static row data to filter... + if (!hasMatch) + { + context.rowsFiltered++; + + if (tree.isSatisfiedBy(partitionKey, staticRow, staticRow)) { - // Record primary keys shadowed by expired TTLs, row tombstones, or range tombstones: - queryContext.vectorContext().recordShadowedPrimaryKeys(keysToShadow); + hasMatch = true; } + } + + if (!hasMatch) + { + // If there are no matches, return an empty partition. If reconciliation is required at the + // coordinator, replica filtering protection may make a second round trip to complete its view + // of the partition. + return null; + } + + // Return all matches found + return matches; + } - if (!hasMatch) + private static class SinglePartitionIterator extends AbstractUnfilteredRowIterator + { + private final Iterator rows; + + public SinglePartitionIterator(UnfilteredRowIterator partition, Row staticRow, Iterator rows) + { + super(partition.metadata(), + partition.partitionKey(), + partition.partitionLevelDeletion(), + partition.columns(), + staticRow, + partition.isReverseOrder(), + partition.stats()); + + this.rows = rows; + } + + @Override + protected Unfiltered computeNext() + { + return rows.hasNext() ? rows.next() : endOfData(); + } + } + + /** + * A result retriever that consumes an iterator primary keys sorted by some score, materializes the row for each + * primary key (currently, each primary key is required to be fully qualified and should only point to one row), + * apply the filter tree to the row to test that the real row satisfies the WHERE clause, and finally tests + * that the row is valid for the ORDER BY clause. The class performs some optimizations to avoid materializing + * rows unnecessarily. See the class for more details. + *

    + * The resulting {@link UnfilteredRowIterator} objects are not guaranteed to be in any particular order. It is + * the responsibility of the caller to sort the results if necessary. + */ + public static class ScoreOrderedResultRetriever extends AbstractIterator implements UnfilteredPartitionIterator + { + private final ColumnFamilyStore.ViewFragment view; + private final List> keyRanges; + private final boolean coversFullRing; + private final CloseableIterator scoredPrimaryKeyIterator; + private final FilterTree filterTree; + private final QueryController controller; + private final ReadExecutionController executionController; + private final QueryContext queryContext; + + private final boolean isVectorColumnStatic; + private final HashSet processedKeys; + private final Queue pendingRows; + + // The limit requested by the query. We cannot load more than softLimit rows in bulk because we only want + // to fetch the topk rows where k is the limit. However, we allow the iterator to fetch more rows than the + // soft limit to avoid confusing behavior. When the softLimit is reached, the iterator will fetch one row + // at a time. + private final int softLimit; + private int returnedRowCount = 0; + + private ScoreOrderedResultRetriever(QueryController controller, + ReadExecutionController executionController, + QueryContext queryContext, + QueryViewBuilder.QueryView queryView, + int limit) + { + assert queryView.view.size() == 1; + QueryViewBuilder.QueryExpressionView queryExpressionView = queryView.view.stream().findFirst().get(); + this.view = queryExpressionView.computeViewFragment(); + this.keyRanges = controller.dataRanges().stream().map(DataRange::keyRange).collect(Collectors.toList()); + this.coversFullRing = keyRanges.size() == 1 && RangeUtil.coversFullRing(keyRanges.get(0)); + + this.scoredPrimaryKeyIterator = Operation.buildIteratorForOrder(controller, queryExpressionView); + this.filterTree = Operation.buildFilter(controller, controller.usesStrictFiltering()); + this.controller = controller; + this.executionController = executionController; + this.queryContext = queryContext; + + this.isVectorColumnStatic = queryExpressionView.expression.getIndexTermType().columnMetadata().isStatic(); + this.processedKeys = new HashSet<>(limit); + this.pendingRows = new ArrayDeque<>(limit); + this.softLimit = limit; + } + + @Override + public UnfilteredRowIterator computeNext() + { + if (pendingRows.isEmpty()) + fillPendingRows(); + returnedRowCount++; + // Because we know ordered keys are fully qualified, we do not iterate partitions + return !pendingRows.isEmpty() ? pendingRows.poll() : endOfData(); + } + + /** + * Fills the pendingRows queue to generate a queue of row iterators for the supplied keys by repeatedly calling + * {@link #readAndValidatePartition} until it gives enough non-null results. + */ + private void fillPendingRows() + { + // Group PKs by source sstable/memtable + Map> groupedKeys = new HashMap<>(); + // We always want to get at least 1. When the vector column is static, we cannot batch because we need to + // retain the score ordering a bit longer. + int rowsToRetrieve = isVectorColumnStatic ? 
1 : Math.max(1, softLimit - returnedRowCount);
+            // We want to get the first unique `rowsToRetrieve` keys to materialize
+            // Don't pass the priority queue here because it is more efficient to add keys in bulk
+            fillKeys(groupedKeys, rowsToRetrieve, null);
+            // Sort the primary keys into PrimaryKey order, just in case that helps with cache and disk efficiency
+            PriorityQueue primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet());
+
+            // drain groupedKeys into pendingRows
+            while (!groupedKeys.isEmpty())
             {
-                // If there are no matches, return an empty partition. If reconciliation is required at the
-                // coordinator, replica filtering protection may make a second round trip to complete its view
-                // of the partition.
-                return null;
+                PrimaryKey pk = primaryKeyPriorityQueue.poll();
+                List sourceKeys = groupedKeys.remove(pk);
+                UnfilteredRowIterator partitionIterator = readAndValidatePartition(pk, sourceKeys);
+                if (partitionIterator != null)
+                    pendingRows.add(partitionIterator);
+                else
+                    // The current primaryKey did not produce a partition iterator. We know the caller will need
+                    // `rowsToRetrieve` rows, so we get the next unique key and add it to the queue.
+                    fillKeys(groupedKeys, 1, primaryKeyPriorityQueue);
             }
         }
-            // Return all matches found, along with the static row...
-            return new PartitionIterator(partition, staticRow, matches.iterator());
+        /**
+         * Fills the `groupedKeys` Map with the next `count` unique primary keys that are in the keys produced by calling
+         * {@link #nextSelectedKeyInRange()}. We map PrimaryKey to a list of PrimaryKeyWithScore because the same
+         * primary key can be in the result set multiple times, but with different source tables.
+         * @param groupedKeys the map to fill
+         * @param count the number of unique PrimaryKeys to consume from the iterator
+         * @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add
+         *                                keys to the queue.
+         */
+        private void fillKeys(Map> groupedKeys, int count, PriorityQueue primaryKeyPriorityQueue)
+        {
+            int initialSize = groupedKeys.size();
+            while (groupedKeys.size() - initialSize < count)
+            {
+                PrimaryKeyWithScore primaryKeyWithScore = nextSelectedKeyInRange();
+                if (primaryKeyWithScore == null)
+                    return;
+                PrimaryKey nextPrimaryKey = primaryKeyWithScore.primaryKey();
+                List accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>());
+                if (primaryKeyPriorityQueue != null && accumulator.isEmpty())
+                    primaryKeyPriorityQueue.add(nextPrimaryKey);
+                accumulator.add(primaryKeyWithScore);
+            }
         }
-        private class PartitionIterator extends AbstractUnfilteredRowIterator
+        /**
+         * Determine if the key is in one of the queried key ranges. We do not iterate through results in
+         * {@link PrimaryKey} order, so we have to check each range.
+         * @param key the key to test
+         * @return true if the key is in one of the queried key ranges
+         */
+        private boolean isInRange(DecoratedKey key)
         {
-            private final Iterator rows;
+            if (coversFullRing)
+                return true;
+
+            for (AbstractBounds range : keyRanges)
+                if (range.contains(key))
+                    return true;
+            return false;
+        }
-            public PartitionIterator(UnfilteredRowIterator partition, Row staticRow, Iterator rows)
+        /**
+         * Returns the next available key contained by one of the keyRanges and selected by the queryController.
+         * If the next key falls out of the current key range, it skips to the next key range, and so on.
+         * If no more keys accepted by the controller are available, returns null.
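+         * <p>
+         * For example (illustrative token values only): given queried key ranges [10, 20) and [40, 50), a scored
+         * key whose token is 25 is skipped, and the first later key that falls inside a range, say token 42, and
+         * that the controller also selects, is returned.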
+ */ + private @Nullable PrimaryKeyWithScore nextSelectedKeyInRange() + { + while (scoredPrimaryKeyIterator.hasNext()) { - super(partition.metadata(), - partition.partitionKey(), - partition.partitionLevelDeletion(), - partition.columns(), - staticRow, - partition.isReverseOrder(), - partition.stats()); - - this.rows = rows; + PrimaryKeyWithScore key = scoredPrimaryKeyIterator.next(); + if (isInRange(key.primaryKey().partitionKey()) && !controller.doesNotSelect(key.primaryKey())) + return key; } + return null; + } - @Override - protected Unfiltered computeNext() + /** + * Reads and validates a partition for a given primary key against its sources. + *

    + * @param pk The primary key of the partition to read and validate + * @param sourceKeys A list of PrimaryKeyWithScore objects associated with the primary key. + * Multiple sort keys can exist for the same primary key when data comes from different + * sstables or memtables. + * + * @return An UnfilteredRowIterator containing the validated partition data, or null if: + * - The key has already been processed + * - The partition does not pass index filters + * - The partition contains no valid rows + * - The row data does not match the index metadata for any of the provided primary keys + */ + public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List sourceKeys) + { + // If we've already processed the key, we can skip it. Because the score ordered iterator does not + // deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens + // in the case of dupes and of overwrites. + if (processedKeys.contains(pk)) + return null; + + try (UnfilteredRowIterator partition = controller.queryStorage(pk, view, executionController)) { - return rows.hasNext() ? rows.next() : endOfData(); + queryContext.partitionsRead++; + queryContext.checkpoint(); + + List clusters = filterPartition(partition, filterTree, queryContext); + + if (clusters == null) + { + // Key counts as processed because the materialized row didn't satisfy the filter logic + processedKeys.add(pk); + return null; + } + + Row staticRow = partition.staticRow(); + long now = FBUtilities.nowInSeconds(); + + // If the pk is static, then we must check that the static row satisfies the source key's validity check. + // Otherwise, we need to make sure that we have one row in the cluster result and then we use that + // for checking validity. + Row representativeRow; + if (pk.kind() == PrimaryKey.Kind.STATIC) + { + representativeRow = staticRow; + } + else + { + if (clusters.isEmpty()) + { + // Key counts as processed because the materialized row didn't satisfy the filter logic + processedKeys.add(pk); + return null; + } + representativeRow = clusters.get(0); + assert clusters.size() == 1 : "Expect 1 result row, but got: " + clusters.size(); + } + + // Each of sourceKeys are equal with respect to primary key equality, but they have different source tables. + // As long as one is valid, we consider the row valid. + for (PrimaryKeyWithScore sourceKey : sourceKeys) + { + assert sourceKey.primaryKey().kind() == pk.kind(); + if (sourceKey.isIndexDataValid(representativeRow, now)) + { + processedKeys.add(pk); + return new SinglePartitionIterator(partition, staticRow, clusters.iterator()); + } + } + // Key does not count as processed because the only thing that "failed" is the validity check on the + // grouped source keys, and it is possible that the score ordered iterator has the same key in the + // iterator lower. We only get here when a vector's value is updated to a more distant vector, so + // the old value ranks high in the iterator, but isn't the current value for the materialized row. 
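+            // Illustrative timeline (hypothetical values): suppose a row's vector is updated from v1 to v2, so an
+            // sstable index still holds (pk, v1) while the memtable index holds (pk, v2). If v1 scores higher for
+            // this query, (pk, v1) reaches this point first, fails isIndexDataValid against the materialized row,
+            // and is dropped without marking pk as processed; (pk, v2) can then still be accepted when it appears
+            // later in the score-ordered stream.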
+            return null;
         }
     }
 
     @Override
     public TableMetadata metadata()
     {
-        return queryController.metadata();
+        return controller.metadata();
     }
 
-    @Override
     public void close()
     {
-        FileUtils.closeQuietly(resultKeyIterator);
-        if (tableQueryMetrics != null) tableQueryMetrics.record(queryContext);
+        FileUtils.closeQuietly(scoredPrimaryKeyIterator);
     }
 }
diff --git a/src/java/org/apache/cassandra/index/sai/plan/VectorTopKProcessor.java b/src/java/org/apache/cassandra/index/sai/plan/VectorTopKProcessor.java
index b21bf38a05d8..9a133c66b868 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/VectorTopKProcessor.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/VectorTopKProcessor.java
@@ -88,7 +88,7 @@ public VectorTopKProcessor(ReadCommand command)
      * Filter given partitions and keep the rows with the highest scores. In case of {@link UnfilteredPartitionIterator},
      * all tombstones will be kept.
      */
-    public , P extends BasePartitionIterator> BasePartitionIterator filter(P partitions)
+    public , P extends BasePartitionIterator> BasePartitionIterator consumeSortByScoreAndTakeTopK(P partitions)
     {
         // priority queue ordered by score in ascending order
         PriorityQueue> topK = new PriorityQueue<>(limit + 1, Comparator.comparing(Triple::getRight));
@@ -164,6 +164,53 @@ private float getScoreForRow(DecoratedKey key, Row row)
         return 0;
     }
 
+    /**
+     * Consumes rows from the given score-ordered partitions until {@code limit} rows are taken, regrouping
+     * them in primary key order. In case of {@link UnfilteredPartitionIterator}, all tombstones will be kept.
+     */
+    public , P extends BasePartitionIterator> BasePartitionIterator takeTopKThenSortByPrimaryKey(P partitions)
+    {
+        try (partitions)
+        {
+            TreeMap> unfilteredByPartition = new TreeMap<>(Comparator.comparing(pi -> pi.key));
+
+            int rowsMatched = 0;
+            while (rowsMatched < limit && partitions.hasNext())
+            {
+                try (BaseRowIterator partitionRowIterator = partitions.next())
+                {
+                    rowsMatched += processSingleRowPartition(unfilteredByPartition, partitionRowIterator, limit - rowsMatched);
+                }
+            }
+
+            return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
+        }
+    }
+
+    /**
+     * Processes a single partition, without scoring it.
+     */
+    private int processSingleRowPartition(TreeMap> unfilteredByPartition,
+                                          BaseRowIterator partitionRowIterator,
+                                          int remaining)
+    {
+        if (!partitionRowIterator.hasNext())
+            return 0;
+
+        // Always include tombstones for coordinator. It relies on ReadCommand#withMetricsRecording to throw
+        // TombstoneOverwhelmingException to prevent OOM.
+        PartitionInfo partitionInfo = PartitionInfo.create(partitionRowIterator);
+        TreeSet map = unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(command.metadata().comparator));
+        int added = 0;
+        while (partitionRowIterator.hasNext() && added < remaining)
+        {
+            Unfiltered unfiltered = partitionRowIterator.next();
+            map.add(unfiltered);
+            if (unfiltered.isRow())
+                added++;
+        }
+        return added;
+    }
 
     private Pair findTopKIndex()
     {
diff --git a/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java b/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java
new file mode 100644
index 000000000000..73316f85e69c
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.utils; + +import java.nio.ByteBuffer; + +import org.apache.cassandra.db.DeletionPurger; +import org.apache.cassandra.db.Digest; +import org.apache.cassandra.db.marshal.ValueAccessor; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.db.rows.CellPath; +import org.apache.cassandra.db.rows.ColumnData; +import org.apache.cassandra.db.rows.ComplexColumnData; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.schema.ColumnMetadata; +import org.apache.cassandra.utils.ObjectSizes; +import org.apache.cassandra.utils.memory.ByteBufferCloner; + +/** + * A wrapped {@link Cell} that includes a reference to the cell's source table. + * @param the type of the cell's value + */ +public class CellWithSource extends Cell +{ + private static final long EMPTY_SIZE = ObjectSizes.measure(new CellWithSource<>(null, null, null)); + + private final Cell cell; + private final Object source; + + public CellWithSource(Cell cell, Object source) + { + this(cell.column(), cell, source); + assert source instanceof Memtable || source instanceof SSTableId : "Source has unexpected type: " + (source == null ? 
"null" : source.getClass()); + } + + private CellWithSource(ColumnMetadata column, Cell cell, Object source) + { + super(column); + this.cell = cell; + this.source = source; + } + + public Object sourceTable() + { + return source; + } + + @Override + public boolean isCounterCell() + { + return cell.isCounterCell(); + } + + @Override + public T value() + { + return cell.value(); + } + + @Override + public ValueAccessor accessor() + { + return cell.accessor(); + } + + @Override + public long timestamp() + { + return cell.timestamp(); + } + + @Override + public int ttl() + { + return cell.ttl(); + } + + @Override + public long localDeletionTime() + { + return cell.localDeletionTime(); + } + + @Override + public boolean isTombstone() + { + return cell.isTombstone(); + } + + @Override + public boolean isExpiring() + { + return cell.isExpiring(); + } + + @Override + public boolean isLive(long nowInSec) + { + return cell.isLive(nowInSec); + } + + @Override + public CellPath path() + { + return cell.path(); + } + + @Override + public Cell withUpdatedColumn(ColumnMetadata newColumn) + { + return wrapIfNew(cell.withUpdatedColumn(newColumn)); + } + + @Override + public Cell withUpdatedValue(ByteBuffer newValue) + { + return wrapIfNew(cell.withUpdatedValue(newValue)); + } + + @Override + public Cell withUpdatedTimestampAndLocalDeletionTime(long newTimestamp, long newLocalDeletionTime) + { + return wrapIfNew(cell.withUpdatedTimestampAndLocalDeletionTime(newTimestamp, newLocalDeletionTime)); + } + + @Override + public Cell withSkippedValue() + { + return wrapIfNew(cell.withSkippedValue()); + } + + @Override + public Cell clone(ByteBufferCloner cloner) + { + return wrapIfNew(cell.clone(cloner)); + } + + @Override + public int dataSize() + { + return cell.dataSize(); + } + + @Override + public long unsharedHeapSizeExcludingData() + { + return cell.unsharedHeapSizeExcludingData(); + } + + @Override + public long unsharedHeapSize() + { + return cell.unsharedHeapSize() + EMPTY_SIZE; + } + + @Override + public void validate() + { + cell.validate(); + } + + @Override + public boolean hasInvalidDeletions() + { + return cell.hasInvalidDeletions(); + } + + @Override + public void digest(Digest digest) + { + cell.digest(digest); + } + + @Override + public ColumnData updateAllTimestamp(long newTimestamp) + { + ColumnData maybeNewCell = cell.updateAllTimestamp(newTimestamp); + if (maybeNewCell instanceof Cell) + return wrapIfNew((Cell) maybeNewCell); + if (maybeNewCell instanceof ComplexColumnData) + return ((ComplexColumnData) maybeNewCell).transform(this::wrapIfNew); + // It's not clear when we would hit this code path, but it seems we should not + // hit this from SAI. + throw new IllegalStateException("Expected a Cell instance, but got " + maybeNewCell); + } + + @Override + public Cell markCounterLocalToBeCleared() + { + return wrapIfNew(cell.markCounterLocalToBeCleared()); + } + + @Override + public Cell purge(DeletionPurger purger, long nowInSec) + { + return wrapIfNew(cell.purge(purger, nowInSec)); + } + + @Override + public Cell purgeDataOlderThan(long timestamp) + { + return wrapIfNew(cell.purgeDataOlderThan(timestamp)); + } + + @Override + protected int localDeletionTimeAsUnsignedInt() + { + // Cannot call cell's localDeletionTimeAsUnsignedInt() because it's protected. 
+ throw new UnsupportedOperationException(); + } + + @Override + public long maxTimestamp() + { + return cell.maxTimestamp(); + } + + private Cell wrapIfNew(Cell maybeNewCell) + { + if (maybeNewCell == null) + return null; + // If the cell's method returned a reference to the same cell, then + // we can skip creating a new wrapper. + if (maybeNewCell == this.cell) + return this; + return new CellWithSource<>(maybeNewCell, source); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/utils/MergePrimaryKeyWithScoreIterator.java b/src/java/org/apache/cassandra/index/sai/utils/MergePrimaryKeyWithScoreIterator.java new file mode 100644 index 000000000000..5930d44994aa --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/MergePrimaryKeyWithScoreIterator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.utils; + +import java.util.Collection; +import java.util.PriorityQueue; + +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; + +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.AbstractIterator; +import org.apache.cassandra.utils.CloseableIterator; + +// TODO: this implementation is sub-optimal due to the combination of PriorityQueue poll/add. 
A non reducing version of +// the MergeIterator would be better +public class MergePrimaryKeyWithScoreIterator extends AbstractIterator +{ + private final PriorityQueue> queue; + private final Collection> iteratorsToClose; + + public MergePrimaryKeyWithScoreIterator(Collection> iterators) + { + assert !iterators.isEmpty(); + iteratorsToClose = iterators; + queue = new PriorityQueue<>(iterators.size(), (a, b) -> a.peek().compareTo(b.peek())); + for (CloseableIterator iterator : iterators) + { + if (iterator.hasNext()) + queue.add(Iterators.peekingIterator(iterator)); + } + } + + @Override + protected PrimaryKeyWithScore computeNext() + { + if (queue.isEmpty()) + return endOfData(); + + PeekingIterator iterator = queue.poll(); + PrimaryKeyWithScore next = iterator.next(); + if (iterator.hasNext()) + queue.add(iterator); + return next; + } + + @Override + public void close() + { + for (CloseableIterator iterator : iteratorsToClose) + FileUtils.closeQuietly(iterator); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java b/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java new file mode 100644 index 000000000000..d4cdc2267600 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.utils; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; + +import com.google.common.collect.Collections2; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; + +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DeletionPurger; +import org.apache.cassandra.db.DeletionTime; +import org.apache.cassandra.db.Digest; +import org.apache.cassandra.db.LivenessInfo; +import org.apache.cassandra.db.filter.ColumnFilter; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.db.rows.CellPath; +import org.apache.cassandra.db.rows.ColumnData; +import org.apache.cassandra.db.rows.ComplexColumnData; +import org.apache.cassandra.db.rows.Row; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.schema.ColumnMetadata; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.utils.BiLongAccumulator; +import org.apache.cassandra.utils.LongAccumulator; +import org.apache.cassandra.utils.ObjectSizes; +import org.apache.cassandra.utils.SearchIterator; +import org.apache.cassandra.utils.memory.Cloner; + +/** + * A Row wrapper that has a source object that gets added to cell as part of the getCell call. This can only be used + * validly when all the cells share a common source object. + */ +public class RowWithSource implements Row +{ + private static final long EMPTY_SIZE = ObjectSizes.measure(new RowWithSource(null, null)); + + private final Row row; + private final Object source; + + public RowWithSource(Row row, Object source) + { + assert source instanceof Memtable || source instanceof SSTableId || (source == null && row == null) : "Expected Memtable or SSTableId, got " + source; + this.row = row; + this.source = source; + } + + @Override + public Kind kind() + { + return row.kind(); + } + + @Override + public Clustering clustering() + { + return row.clustering(); + } + + @Override + public void digest(Digest digest) + { + row.digest(digest); + } + + @Override + public void validateData(TableMetadata metadata) + { + row.validateData(metadata); + } + + @Override + public boolean hasInvalidDeletions() + { + return row.hasInvalidDeletions(); + } + + @Override + public Collection columns() + { + return row.columns(); + } + + @Override + public int columnCount() + { + return row.columnCount(); + } + + @Override + public Deletion deletion() + { + return row.deletion(); + } + + @Override + public LivenessInfo primaryKeyLivenessInfo() + { + return row.primaryKeyLivenessInfo(); + } + + @Override + public boolean isStatic() + { + return row.isStatic(); + } + + @Override + public boolean isEmpty() + { + return row.isEmpty(); + } + + @Override + public String toString(TableMetadata metadata) + { + return row.toString(metadata); + } + + @Override + public boolean hasLiveData(long nowInSec, boolean enforceStrictLiveness) + { + return row.hasLiveData(nowInSec, enforceStrictLiveness); + } + + @Override + public Cell getCell(ColumnMetadata c) + { + Cell cell = row.getCell(c); + if (cell == null) + return null; + return new CellWithSource<>(cell, source); + } + + @Override + public Cell getCell(ColumnMetadata c, CellPath path) + { + return wrapCell(row.getCell(c, path)); + } + + @Override + public ComplexColumnData getComplexColumnData(ColumnMetadata c) 
+ { + return (ComplexColumnData) wrapColumnData(row.getComplexColumnData(c)); + } + + @Override + public ColumnData getColumnData(ColumnMetadata c) + { + return wrapColumnData(row.getColumnData(c)); + } + + @Override + public Iterable> cells() + { + return Iterables.transform(row.cells(), this::wrapCell); + } + + @Override + public Collection columnData() + { + return Collections2.transform(row.columnData(), this::wrapColumnData); + } + + @Override + public Iterable> cellsInLegacyOrder(TableMetadata metadata, boolean reversed) + { + return Iterables.transform(row.cellsInLegacyOrder(metadata, reversed), this::wrapCell); + } + + @Override + public boolean hasComplexDeletion() + { + return row.hasComplexDeletion(); + } + + @Override + public boolean hasComplex() + { + return row.hasComplex(); + } + + @Override + public boolean hasDeletion(long nowInSec) + { + return row.hasDeletion(nowInSec); + } + + @Override + public SearchIterator searchIterator() + { + SearchIterator iterator = row.searchIterator(); + return key -> wrapColumnData(iterator.next(key)); + } + + @Override + public Row filter(ColumnFilter filter, TableMetadata metadata) + { + return maybeWrapRow(row.filter(filter, metadata)); + } + + @Override + public Row filter(ColumnFilter filter, DeletionTime activeDeletion, boolean setActiveDeletionToRow, TableMetadata metadata) + { + return maybeWrapRow(row.filter(filter, activeDeletion, setActiveDeletionToRow, metadata)); + } + + @Override + public Row transformAndFilter(LivenessInfo info, Deletion deletion, Function function) + { + return maybeWrapRow(row.transformAndFilter(info, deletion, function)); + } + + @Override + public Row transformAndFilter(Function function) + { + return maybeWrapRow(row.transformAndFilter(function)); + } + + @Override + public Row clone(Cloner cloner) + { + return maybeWrapRow(row.clone(cloner)); + } + + @Override + public Row purgeDataOlderThan(long timestamp, boolean enforceStrictLiveness) + { + return maybeWrapRow(row.purgeDataOlderThan(timestamp, enforceStrictLiveness)); + } + + @Override + public Row purge(DeletionPurger purger, long nowInSec, boolean enforceStrictLiveness) + { + return maybeWrapRow(row.purge(purger, nowInSec, enforceStrictLiveness)); + } + + @Override + public Row withOnlyQueriedData(ColumnFilter filter) + { + return maybeWrapRow(row.withOnlyQueriedData(filter)); + } + + @Override + public Row markCounterLocalToBeCleared() + { + return maybeWrapRow(row.markCounterLocalToBeCleared()); + } + + @Override + public Row updateAllTimestamp(long newTimestamp) + { + return maybeWrapRow(row.updateAllTimestamp(newTimestamp)); + } + + @Override + public Row withRowDeletion(DeletionTime deletion) + { + return maybeWrapRow(row.withRowDeletion(deletion)); + } + + @Override + public int dataSize() + { + return row.dataSize(); + } + + @Override + public long unsharedHeapSize() + { + return row.unsharedHeapSize() + EMPTY_SIZE; + } + + @Override + public long unsharedHeapSizeExcludingData() + { + return row.unsharedHeapSizeExcludingData() + EMPTY_SIZE; + } + + @Override + public String toString(TableMetadata metadata, boolean fullDetails) + { + return row.toString(metadata, fullDetails); + } + + @Override + public String toString(TableMetadata metadata, boolean includeClusterKeys, boolean fullDetails) + { + return row.toString(metadata, includeClusterKeys, fullDetails); + } + + @Override + public void apply(Consumer function) + { + row.apply(function); + } + + @Override + public void apply(BiConsumer function, A arg) + { + row.apply(function, arg); + } 
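+
+    // Usage sketch (illustrative only; the actual wrapping sites live elsewhere in this patch): tag a row
+    // as it is read from a particular source so downstream consumers can recover the origin of each cell:
+    //
+    //   Row tagged = new RowWithSource(row, sstable.getId());      // or the source Memtable
+    //   Cell<?> cell = tagged.getCell(column);                     // returns a CellWithSource
+    //   Object origin = ((CellWithSource<?>) cell).sourceTable();  // SSTableId or Memtable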
+ + @Override + public long accumulate(LongAccumulator accumulator, long initialValue) + { + return row.accumulate(accumulator, initialValue); + } + + @Override + public long accumulate(LongAccumulator accumulator, Comparator comparator, ColumnData from, long initialValue) + { + return row.accumulate(accumulator, comparator, from, initialValue); + } + + @Override + public long accumulate(BiLongAccumulator accumulator, A arg, long initialValue) + { + return row.accumulate(accumulator, arg, initialValue); + } + + @Override + public long accumulate(BiLongAccumulator accumulator, A arg, Comparator comparator, ColumnData from, long initialValue) + { + return row.accumulate(accumulator, arg, comparator, from, initialValue); + } + + @Override + public Iterator iterator() + { + return Iterators.transform(row.iterator(), this::wrapColumnData); + } + + private ColumnData wrapColumnData(ColumnData c) + { + if (c == null) + return null; + if (c instanceof Cell) + return new CellWithSource<>((Cell) c, source); + if (c instanceof ComplexColumnData) + return ((ComplexColumnData) c).transform(c1 -> new CellWithSource<>(c1, source)); + throw new IllegalStateException("Unexpected ColumnData type: " + c.getClass().getName()); + } + + private Cell wrapCell(Cell c) + { + return c != null ? new CellWithSource<>(c, source) : null; + } + + private Row maybeWrapRow(Row r) + { + if (r == null) + return null; + if (r == this.row) + return this; + return new RowWithSource(r, source); + } + + @Override + public String toString() + { + return "RowWithSourceTable{" + + row + + ", source=" + source + + '}'; + } +} diff --git a/src/java/org/apache/cassandra/index/sai/view/IndexViewManager.java b/src/java/org/apache/cassandra/index/sai/view/IndexViewManager.java index 24be0276cf50..76e77a984252 100644 --- a/src/java/org/apache/cassandra/index/sai/view/IndexViewManager.java +++ b/src/java/org/apache/cassandra/index/sai/view/IndexViewManager.java @@ -167,13 +167,6 @@ private Pair, Collection> getBuiltIndex continue; } - if (sstableContext.indexDescriptor.isIndexEmpty(index.termType(), index.identifier())) - { - logger.debug(index.identifier().logMessage("No on-disk index was built for SSTable {} because the SSTable " + - "had no indexable rows for the index."), sstableContext.descriptor()); - continue; - } - try { if (validation != IndexValidation.NONE) @@ -186,7 +179,19 @@ private Pair, Collection> getBuiltIndex } SSTableIndex ssTableIndex = sstableContext.newSSTableIndex(index); - logger.debug(index.identifier().logMessage("Successfully created index for SSTable {}."), sstableContext.descriptor()); + // We used to skip these empty indexes. However, that leads to logically incomplete views of the table, + // so we keep them in the view now. For example, vector indexes use the view to materialize rows, and + // without a complete view, an sstable with no indexable vectors might still have valid data or + // tombstones necessary to ensure proper row materialization. + if (ssTableIndex.getRowCount() == 0) + { + logger.debug(index.identifier().logMessage("No on-disk index was built for SSTable {} because the SSTable " + + "had no indexable rows for the index."), sstableContext.descriptor()); + } + else + { + logger.debug(index.identifier().logMessage("Successfully created index for SSTable {}."), sstableContext.descriptor()); + } // Try to add new index to the set, if set already has such index, we'll simply release and move on. 
// This covers situation when SSTable collection has the same SSTable multiple diff --git a/src/java/org/apache/cassandra/index/sai/view/View.java b/src/java/org/apache/cassandra/index/sai/view/View.java index 2e30d61422d9..da88059ba1c9 100644 --- a/src/java/org/apache/cassandra/index/sai/view/View.java +++ b/src/java/org/apache/cassandra/index/sai/view/View.java @@ -52,7 +52,9 @@ public View(IndexTermType indexTermType, Collection indexes) for (SSTableIndex sstableIndex : indexes) { this.view.put(sstableIndex.getSSTable().descriptor, sstableIndex); - if (!indexTermType.isVector()) + // Skip vector indexes, since they are scatter gather for all terms. Skip empty indexes since they + // cannot be inserted into the tree due to the lack of min and max terms. + if (!indexTermType.isVector() && sstableIndex.getRowCount() > 0) rangeTermTreeBuilder.add(sstableIndex); } diff --git a/src/java/org/apache/cassandra/index/sai/virtual/SSTableIndexesSystemView.java b/src/java/org/apache/cassandra/index/sai/virtual/SSTableIndexesSystemView.java index fd6f0cd4cc9f..ae466bdf6f09 100644 --- a/src/java/org/apache/cassandra/index/sai/virtual/SSTableIndexesSystemView.java +++ b/src/java/org/apache/cassandra/index/sai/virtual/SSTableIndexesSystemView.java @@ -103,6 +103,12 @@ public DataSet data() for (SSTableIndex sstableIndex : index.view()) { + // Empty indexes are tracked internally for the sake of having complete views. However, + // these indexes have not historically been exposed in this virtual table, so we skip + // them for now. + if (sstableIndex.getRowCount() == 0) + continue; + SSTableReader sstable = sstableIndex.getSSTable(); Descriptor descriptor = sstable.descriptor; AbstractBounds bounds = sstable.getBounds(); diff --git a/src/java/org/apache/cassandra/index/sasi/SASIIndex.java b/src/java/org/apache/cassandra/index/sasi/SASIIndex.java index 7b9ede5b78be..e4e56b542cc0 100644 --- a/src/java/org/apache/cassandra/index/sasi/SASIIndex.java +++ b/src/java/org/apache/cassandra/index/sasi/SASIIndex.java @@ -359,7 +359,7 @@ else if (notification instanceof MemtableRenewedNotification) } else if (notification instanceof MemtableSwitchedNotification) { - index.switchMemtable(((MemtableSwitchedNotification) notification).memtable); + index.switchMemtable(((MemtableSwitchedNotification) notification).previous); } else if (notification instanceof MemtableDiscardedNotification) { diff --git a/src/java/org/apache/cassandra/io/sstable/SSTable.java b/src/java/org/apache/cassandra/io/sstable/SSTable.java index ea3fbbe52a9a..9948a7194c84 100644 --- a/src/java/org/apache/cassandra/io/sstable/SSTable.java +++ b/src/java/org/apache/cassandra/io/sstable/SSTable.java @@ -205,6 +205,11 @@ public String getKeyspaceName() return descriptor.ksname; } + public SSTableId getId() + { + return descriptor.id; + } + public List getAllFilePaths() { List ret = new ArrayList<>(components.size()); diff --git a/src/java/org/apache/cassandra/notifications/MemtableSwitchedNotification.java b/src/java/org/apache/cassandra/notifications/MemtableSwitchedNotification.java index b1737bebfd11..0adf0cef8739 100644 --- a/src/java/org/apache/cassandra/notifications/MemtableSwitchedNotification.java +++ b/src/java/org/apache/cassandra/notifications/MemtableSwitchedNotification.java @@ -21,10 +21,12 @@ public class MemtableSwitchedNotification implements INotification { - public final Memtable memtable; + public final Memtable previous; + public final Memtable next; - public MemtableSwitchedNotification(Memtable switched) + public 
MemtableSwitchedNotification(Memtable switched, Memtable next) { - this.memtable = switched; + this.previous = switched; + this.next = next; } } diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index 1dad1a5ae7e3..e4eaec08eab4 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -43,6 +43,7 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -3253,6 +3254,41 @@ protected List list(Object...values) return Arrays.asList(values); } + /** @return a normalized vector with the given dimension */ + public Vector randomVectorBoxed(int dimension) + { + float[] floats = randomVector(dimension); + return vector(floats); + } + + public float[] randomVector(int dimension) + { + // this can be called from concurrent threads so don't use getRandom() + ThreadLocalRandom R = ThreadLocalRandom.current(); + + float[] vector = new float[dimension]; + for (int i = 0; i < dimension; i++) + { + vector[i] = R.nextFloat(); + } + normalize(vector); + return vector; + } + + /** Normalize the given vector in-place */ + protected static void normalize(float[] v) + { + float sum = 0.0f; + for (int i = 0; i < v.length; i++) + { + sum += v[i] * v[i]; + } + + sum = (float) Math.sqrt(sum); + for (int i = 0; i < v.length; i++) + v[i] /= sum; + } + @SafeVarargs protected final Vector vector(T... values) { diff --git a/test/unit/org/apache/cassandra/db/lifecycle/TrackerTest.java b/test/unit/org/apache/cassandra/db/lifecycle/TrackerTest.java index 90aa9f7ba860..03ac14f364ef 100644 --- a/test/unit/org/apache/cassandra/db/lifecycle/TrackerTest.java +++ b/test/unit/org/apache/cassandra/db/lifecycle/TrackerTest.java @@ -320,7 +320,8 @@ public void testMemtableReplacement() tracker = cfs.getTracker(); listener = new MockListener(false); tracker.subscribe(listener); - prev1 = tracker.switchMemtable(false, cfs.createMemtable(new AtomicReference<>(CommitLog.instance.getCurrentPosition()))); + Memtable next1 = cfs.createMemtable(new AtomicReference<>(CommitLog.instance.getCurrentPosition())); + prev1 = tracker.switchMemtable(false, next1); tracker.markFlushing(prev1); reader = MockSchema.sstable(0, 10, true, cfs); cfs.invalidate(false); @@ -329,7 +330,8 @@ public void testMemtableReplacement() Assert.assertEquals(0, tracker.getView().flushingMemtables.size()); Assert.assertEquals(0, cfs.metric.liveDiskSpaceUsed.getCount()); Assert.assertEquals(5, listener.received.size()); - Assert.assertEquals(prev1, ((MemtableSwitchedNotification) listener.received.get(0)).memtable); + Assert.assertEquals(prev1, ((MemtableSwitchedNotification) listener.received.get(0)).previous); + Assert.assertEquals(next1, ((MemtableSwitchedNotification) listener.received.get(0)).next); Assert.assertEquals(singleton(reader), ((SSTableAddedNotification) listener.received.get(1)).added); Assert.assertEquals(Optional.of(prev1), ((SSTableAddedNotification) listener.received.get(1)).memtable()); Assert.assertEquals(prev1, ((MemtableDiscardedNotification) listener.received.get(2)).memtable); diff --git a/test/unit/org/apache/cassandra/index/sai/SAITester.java b/test/unit/org/apache/cassandra/index/sai/SAITester.java index 00896a850c2b..4672cbaf437e 100644 --- a/test/unit/org/apache/cassandra/index/sai/SAITester.java +++ 
b/test/unit/org/apache/cassandra/index/sai/SAITester.java @@ -576,6 +576,11 @@ protected void verifySSTableIndexes(IndexIdentifier indexIdentifier, int count) } protected void verifySSTableIndexes(IndexIdentifier indexIdentifier, int sstableContextCount, int sstableIndexCount) + { + verifySSTableIndexes(indexIdentifier, sstableContextCount, sstableIndexCount, 0); + } + + protected void verifySSTableIndexes(IndexIdentifier indexIdentifier, int sstableContextCount, int sstableIndexCount, int expectedEmptyIndexCount) { ColumnFamilyStore cfs = getCurrentColumnFamilyStore(); StorageAttachedIndexGroup indexGroup = getCurrentIndexGroup(); @@ -584,7 +589,10 @@ protected void verifySSTableIndexes(IndexIdentifier indexIdentifier, int sstable StorageAttachedIndex sai = (StorageAttachedIndex) cfs.indexManager.getIndexByName(indexIdentifier.indexName); Collection sstableIndexes = sai == null ? Collections.emptyList() : sai.view().getIndexes(); - assertEquals("Expected " + sstableIndexCount +" SSTableIndexes, but got " + sstableIndexes.toString(), sstableIndexCount, sstableIndexes.size()); + long nonEmptyIndexCount = sstableIndexes.stream().filter(i -> i.getRowCount() > 0).count(); + long emptyIndexCount = sstableIndexes.stream().filter(i -> i.getRowCount() == 0).count(); + assertEquals("Expected " + sstableIndexCount +" SSTableIndexes, but got " + sstableIndexes.toString(), sstableIndexCount, nonEmptyIndexCount); + assertEquals("Expected " + expectedEmptyIndexCount + " empty indexes, but got " + emptyIndexCount, expectedEmptyIndexCount, emptyIndexCount); } protected boolean isBuildCompletionMarker(IndexComponent indexComponent) diff --git a/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java index e81869bf3494..69f639238a84 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java @@ -1168,8 +1168,8 @@ private void verifyFlushAndCompactEmptyIndexes(Runnable populateData) IndexTermType numericIndexTermType = createIndexTermType(Int32Type.instance); IndexTermType literalIndexTermType = createIndexTermType(UTF8Type.instance); populateData.run(); - verifySSTableIndexes(numericIndexIdentifier, 2, 0); - verifySSTableIndexes(literalIndexIdentifier, 2, 0); + verifySSTableIndexes(numericIndexIdentifier, 2, 0, 2); + verifySSTableIndexes(literalIndexIdentifier, 2, 0, 2); verifyIndexFiles(numericIndexTermType, numericIndexIdentifier, 2, 0, 2); verifyIndexFiles(literalIndexTermType, literalIndexIdentifier, 2, 0, 2); @@ -1180,8 +1180,8 @@ private void verifyFlushAndCompactEmptyIndexes(Runnable populateData) // compact empty index compact(); - verifySSTableIndexes(numericIndexIdentifier, 1, 0); - verifySSTableIndexes(literalIndexIdentifier, 1, 0); + verifySSTableIndexes(numericIndexIdentifier, 1, 0, 1); + verifySSTableIndexes(literalIndexIdentifier, 1, 0, 1); waitForAssert(() -> verifyIndexFiles(numericIndexTermType, numericIndexIdentifier, 1, 0, 1)); waitForAssert(() -> verifyIndexFiles(literalIndexTermType, literalIndexIdentifier, 1, 0, 1)); diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java index 3cc146ecf6fe..94ab87b4a84a 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java @@ -34,10 
+34,11 @@ import org.junit.Test;
 import org.apache.cassandra.cql3.UntypedResultSet;
+import org.apache.cassandra.index.sai.SAITester;
 
 import static org.junit.Assert.assertTrue;
 
-public class VectorSiftSmallTest extends VectorTester
+public class VectorSiftSmallTest extends SAITester
 {
     @Test
     public void testSiftSmall() throws Throwable
@@ -60,6 +61,60 @@ public void testSiftSmall() throws Throwable
         assertTrue("Disk recall is " + diskRecall, diskRecall > 0.95);
     }
 
+    @Test
+    public void testSiftSmallWithBooleanPredicatesOfVaryingSelectivity() throws Throwable
+    {
+        var siftName = "siftsmall";
+        var baseVectors = readFvecs(String.format("test/data/%s/%s_base.fvecs", siftName, siftName));
+        var queryVectors = readFvecs(String.format("test/data/%s/%s_query.fvecs", siftName, siftName));
+
+        // Create table with regular id column and add SAI index on it
+        createTable("CREATE TABLE %s (pk int, id int, val vector, PRIMARY KEY(pk))");
+        createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'");
+        createIndex("CREATE CUSTOM INDEX ON %s(id) USING 'StorageAttachedIndex'");
+
+        // Insert all vectors into unique partitions with id column
+        insertVectorsWithId(baseVectors);
+
+        int totalVectors = baseVectors.size();
+        int topK = 100;
+
+        // Test with tiny range restriction (1% of data)
+        int tinyRangeEnd = totalVectors / 100;
+        // Test with small range restriction (10% of data)
+        int smallRangeEnd = totalVectors / 10;
+        // Test with medium range restriction (50% of data)
+        int mediumRangeEnd = totalVectors / 2;
+
+        // Execute queries of different selectivity. We compute the ground truth for each query vector by brute
+        // force in this test class because no precomputed ground truth exists for these filtered queries.
+        beforeAndAfterFlush(() -> {
+            double tinyRangeRecall = testRecallWithIdRange(queryVectors, baseVectors, 0, tinyRangeEnd, topK);
+            assertTrue("Tiny range recall is " + tinyRangeRecall, tinyRangeRecall >= 0.99);
+
+            double smallRangeRecall = testRecallWithIdRange(queryVectors, baseVectors, 0, smallRangeEnd, topK);
+            assertTrue("Small range recall is " + smallRangeRecall, smallRangeRecall >= 0.975);
+
+            double mediumRangeRecall = testRecallWithIdRange(queryVectors, baseVectors, 0, mediumRangeEnd, topK);
+            assertTrue("Medium range recall is " + mediumRangeRecall, mediumRangeRecall >= 0.975);
+        });
+
+        // Finish by deleting all rows and verifying queries still return correctly.
+        // We reuse this test because it already loads many vectors.
+        disableCompaction();
+        for (int i = 0; i < totalVectors; i++)
+            execute("DELETE FROM %s WHERE pk = ?", i);
+
+        float[] vec = baseVectors.get(0);
+        beforeAndAfterFlush(() -> {
+            // Confirm all queries produce 0 rows. These queries hit several edge cases that are otherwise hard to
+            // test, so we add them at the end of this hybrid sift test.
+            assertRows(execute("SELECT pk FROM %s WHERE id >= ? AND id <= ? ORDER BY val ANN OF ? LIMIT ?", 0, tinyRangeEnd, vector(vec), 10));
+            assertRows(execute("SELECT pk FROM %s WHERE id >= ? AND id <= ? ORDER BY val ANN OF ? LIMIT ?", 0, smallRangeEnd, vector(vec), 10));
+            assertRows(execute("SELECT pk FROM %s WHERE id >= ? AND id <= ? ORDER BY val ANN OF ?
LIMIT ?", 0, mediumRangeEnd, vector(vec), 10)); + }); + } + public static ArrayList readFvecs(String filePath) throws IOException { var vectors = new ArrayList(); @@ -128,7 +183,7 @@ public double testRecall(List queryVectors, List> grou UntypedResultSet result = execute("SELECT pk FROM %s ORDER BY val ANN OF " + queryVectorAsString + " LIMIT " + topK); var gt = groundTruth.get(i); - int n = (int)result.stream().filter(row -> gt.contains(row.getInt("pk"))).count(); + int n = (int) result.stream().filter(row -> gt.contains(row.getInt("pk"))).count(); topKfound.addAndGet(n); } catch (Throwable throwable) @@ -155,4 +210,88 @@ private void insertVectors(List baseVectors) } }); } + + private void insertVectorsWithId(List baseVectors) + { + IntStream.range(0, baseVectors.size()).parallel().forEach(i -> { + float[] arrayVector = baseVectors.get(i); + String vectorAsString = Arrays.toString(arrayVector); + try + { + execute("INSERT INTO %s " + String.format("(pk, id, val) VALUES (%d, %d, %s)", i, i, vectorAsString)); + } + catch (Throwable throwable) + { + throw new RuntimeException(throwable); + } + }); + } + + private double testRecallWithIdRange(List queryVectors, List baseVectors, + int idStart, int idEnd, int topK) + { + AtomicInteger topKfound = new AtomicInteger(0); + + // Perform query with id range restriction and compute recall + IntStream.range(0, queryVectors.size()).parallel().forEach(i -> { + float[] queryVector = queryVectors.get(i); + String queryVectorAsString = Arrays.toString(queryVector); + + try + { + // Compute ground truth for this filtered range by brute force + var filteredGroundTruth = computeGroundTruthForRange(queryVector, baseVectors, idStart, idEnd, topK); + + String query = String.format("SELECT pk FROM %%s WHERE id >= %d AND id <= %d ORDER BY val ANN OF %s LIMIT %d", + idStart, idEnd, queryVectorAsString, topK); + UntypedResultSet result = execute(query); + + // Count how many results are in the ground truth + int n = (int) result.stream().filter(row -> filteredGroundTruth.contains(row.getInt("pk"))).count(); + topKfound.addAndGet(n); + } + catch (Throwable throwable) + { + throw new RuntimeException(throwable); + } + }); + + return (double) topKfound.get() / (queryVectors.size() * topK); + } + + private HashSet computeGroundTruthForRange(float[] queryVector, List baseVectors, + int idStart, int idEnd, int topK) + { + // Create a list of (id, distance) pairs for vectors in the range + var candidates = new ArrayList>(); + + for (int id = idStart; id <= idEnd && id < baseVectors.size(); id++) + { + float distance = euclideanDistance(queryVector, baseVectors.get(id)); + candidates.add(new java.util.AbstractMap.SimpleEntry<>(id, distance)); + } + + // Sort by distance and take top K + candidates.sort(java.util.Map.Entry.comparingByValue()); + + var groundTruth = new HashSet(); + int limit = Math.min(topK, candidates.size()); + for (int i = 0; i < limit; i++) + { + groundTruth.add(candidates.get(i).getKey()); + } + + return groundTruth; + } + + private float euclideanDistance(float[] a, float[] b) + { + float sum = 0.0f; + for (int i = 0; i < a.length; i++) + { + float diff = a[i] - b[i]; + sum += diff * diff; + } + return (float) Math.sqrt(sum); + } } diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorTester.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorTester.java index ab72aabccf57..511831fd96e4 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/VectorTester.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorTester.java @@ 
-26,19 +26,22 @@ import org.junit.Before; import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.disk.v1.segment.VectorIndexSegmentSearcher; import org.apache.cassandra.index.sai.disk.v1.vector.ConcurrentVectorValues; import org.apache.cassandra.index.sai.utils.Glove; -import org.apache.cassandra.inject.ActionBuilder; -import org.apache.cassandra.inject.Injections; -import org.apache.cassandra.inject.InvokePointBuilder; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.graph.GraphSearcher; import io.github.jbellis.jvector.vector.VectorEncoding; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +@Ignore +@RunWith(Parameterized.class) public class VectorTester extends SAITester { protected static Glove.WordVector word2vec; @@ -49,29 +52,19 @@ public static void loadModel() throws Throwable word2vec = Glove.parse(VectorTester.class.getClassLoader().getResourceAsStream("glove.3K.50d.txt")); } + @Parameterized.Parameter + public Boolean forceBruteForceQueries; + + @Parameterized.Parameters(name = "forceBruteForceQueries={0}") + public static Iterable data() + { + return Arrays.asList(new Object[][]{{true}, {false}, {null}}); + } + @Before public void setup() throws Throwable { - // override maxBruteForceRows to a random number between 0 and 4 so that we make sure - // the non-brute-force path gets called during tests (which mostly involve small numbers of rows) - var n = getRandom().nextIntBetween(0, 4); - var limitToTopResults = InvokePointBuilder.newInvokePoint() - .onClass("org.apache.cassandra.index.sai.disk.v2.V2VectorIndexSearcher") - .onMethod("limitToTopResults") - .atEntry(); - var bitsOrPostingListForKeyRange = InvokePointBuilder.newInvokePoint() - .onClass("org.apache.cassandra.index.sai.disk.v2.V2VectorIndexSearcher") - .onMethod("bitsOrPostingListForKeyRange") - .atEntry(); - var ab = ActionBuilder.newActionBuilder() - .actions() - .doAction("$this.globalBruteForceRows = " + n); - var changeBruteForceThreshold = Injections.newCustom("force_non_bruteforce_queries") - .add(limitToTopResults) - .add(bitsOrPostingListForKeyRange) - .add(ab) - .build(); - Injections.inject(changeBruteForceThreshold); + VectorIndexSegmentSearcher.FORCE_BRUTE_FORCE_ANN = forceBruteForceQueries; } public static double rawIndexedRecall(Collection vectors, float[] query, List result, int topK) throws IOException diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorTypeTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorTypeTest.java index 94c8b39b5155..ef543ce1d3cf 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/VectorTypeTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorTypeTest.java @@ -35,6 +35,7 @@ import org.apache.cassandra.dht.Token; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.plan.QueryController; import org.apache.cassandra.service.ClientWarn; import static org.assertj.core.api.Assertions.assertThat; @@ -675,4 +676,48 @@ public void multiPartitionUpdateMultiIndexTest() execute("INSERT INTO %s (pk, metadata, row_v) VALUES (10, {'map_k' : 'map_v'}, [0.11, 0.19])"); assertRows(execute(select), row); } + + @Test + public void testStaticVectorColumnIndex() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck 
int, val vector static, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, ck, val) VALUES (0, 1, [1,0])"); + execute("INSERT INTO %s (pk, ck) VALUES (0, 2)"); + execute("INSERT INTO %s (pk, ck, val) VALUES (1, 3, [0,-1])"); + execute("INSERT INTO %s (pk, ck, val) VALUES (2, 4, [0,1])"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT ck FROM %s ORDER BY val ANN OF [0,1] LIMIT 3"), row(4), row(1), row(2)); + assertRows(execute("SELECT ck FROM %s ORDER BY val ANN OF [0,1] LIMIT 2"), row(4), row(1)); + }); + } + + @Test + public void testTooManyMaterializedKeys() throws Throwable + { + int originalValue = QueryController.MAX_MATERIALIZED_KEYS; + QueryController.MAX_MATERIALIZED_KEYS = 10; + try + { + createTable("CREATE TABLE %s (pk int primary key, i int, val vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(i) USING 'StorageAttachedIndex'"); + + for (int i = 1; i <= QueryController.MAX_MATERIALIZED_KEYS * 10; i++) + execute("INSERT INTO %s (pk, i, val) VALUES (?, ?, [1,0])", i, i); + + beforeAndAfterFlush(() -> { + // Search for less than half of the table, which is over the MAX_MATERIALIZED_KEYS value to trigger + // the switched order by then filter query execution. + UntypedResultSet rows = execute("SELECT pk FROM %s WHERE i < ? ORDER BY val ANN OF [0,1] LIMIT 3", QueryController.MAX_MATERIALIZED_KEYS * 2); + assertRowCount(rows, 3); + }); + } + finally + { + QueryController.MAX_MATERIALIZED_KEYS = originalValue; + } + } } diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorUpdateDeleteTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorUpdateDeleteTest.java index 31bf3dfb7cb0..66de0c39a56f 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/VectorUpdateDeleteTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorUpdateDeleteTest.java @@ -21,9 +21,10 @@ import org.junit.Test; import org.apache.cassandra.cql3.UntypedResultSet; +import org.apache.cassandra.index.sai.utils.IndexIdentifier; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE; import static org.apache.cassandra.index.sai.cql.VectorTypeTest.assertContainsInt; +import static org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph.MIN_PQ_ROWS; import static org.assertj.core.api.Assertions.assertThat; public class VectorUpdateDeleteTest extends VectorTester @@ -82,6 +83,21 @@ public void rowDeleteVectorInMemoryAndFlushTest() assertContainsInt(result, "pk", 0); } + @Test + public void testFlushWithDeletedVectors() + { + createTable("CREATE TABLE %s (pk int, v vector, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, v) VALUES (0, [1.0, 2.0])"); + execute("INSERT INTO %s (pk, v) VALUES (0, null)"); + + flush(); + + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY v ann of [2.5, 3.5] LIMIT 1"); + assertThat(result).hasSize(0); + } + // range delete won't trigger UpdateTransaction#onUpdated @Test public void rangeDeleteVectorInMemoryAndFlushTest() @@ -354,7 +370,7 @@ public void updateOtherColumnsTest() execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [2.0, 3.0, 4.0])"); execute("UPDATE %s SET str_val='C' WHERE pk=0"); - var result = execute("SELECT * FROM %s ORDER BY val ann of [0.5, 1.5, 2.5] LIMIT 2"); + UntypedResultSet result = execute("SELECT * FROM %s ORDER 
BY val ann of [0.5, 1.5, 2.5] LIMIT 2"); assertThat(result).hasSize(2); } @@ -387,7 +403,7 @@ public void updateManySSTablesTest() execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [2.0, 3.0, 4.0])"); flush(); - var result = execute("SELECT * FROM %s ORDER BY val ann of [9.5, 10.5, 11.5] LIMIT 1"); + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY val ann of [9.5, 10.5, 11.5] LIMIT 1"); assertThat(result).hasSize(1); assertContainsInt(result, "pk", 0); result = execute("SELECT * FROM %s ORDER BY val ann of [0.5, 1.5, 2.5] LIMIT 1"); @@ -415,10 +431,36 @@ public void shadowedPrimaryKeyInDifferentSSTable() flush(); // the shadow vector has the highest score - var result = execute("SELECT * FROM %s ORDER BY val ann of [1.0, 2.0, 3.0] LIMIT 1"); + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY val ann of [1.0, 2.0, 3.0] LIMIT 1"); assertThat(result).hasSize(1); } + @Test + public void shadowedPrimaryKeyWithSharedVectorAndOtherPredicates() + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, str_val text, val vector&lt;float, 3&gt;)"); + createIndex("CREATE CUSTOM INDEX ON %s(str_val) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + disableCompaction(KEYSPACE); + + // flush a sstable with one vector that is shared by two rows + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'A', [1.0, 2.0, 3.0])"); + flush(); + + // flush another sstable to shadow row 0 + execute("DELETE FROM %s where pk = 0"); + flush(); + + // flush another sstable with one new vector row + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'A', [2.0, 3.0, 4.0])"); + flush(); + + // the shadowed vector has the highest score, but we shouldn't see it + UntypedResultSet result = execute("SELECT pk FROM %s WHERE str_val = 'A' ORDER BY val ann of [1.0, 2.0, 3.0] LIMIT 2"); + assertRowsIgnoringOrder(result, row(2), row(1)); + } + @Test public void testVectorRowWhereUpdateMakesRowMatchNonOrderingPredicates() { @@ -496,52 +538,386 @@ public void testUpdateNonVectorColumnWhereNoSingleSSTableRowMatchesAllPredicates assertRows(execute("SELECT pk FROM %s WHERE val1 = 'match me' AND val2 = 'match me' ORDER BY vec ANN OF [11,11] LIMIT 2"), row(1), row(2)); } + @Test - public void ensureVariableChunkSizeDoesNotLeadToIncorrectResults() throws Exception + public void shadowedPrimaryKeyWithUpdatedPredicateMatchingIntValue() throws Throwable { - // When adding the chunk size feature, there were issues related to leaked files. 
- // This setting only matters for hybrid queries - createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, str_val text, vec vector&lt;float, 2&gt;)"); - createIndex("CREATE CUSTOM INDEX ON %s(vec) USING 'StorageAttachedIndex' WITH OPTIONS = { 'similarity_function' : 'euclidean' }"); + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, num int, val vector&lt;float, 3&gt;)"); + createIndex("CREATE CUSTOM INDEX ON %s(num) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + disableCompaction(KEYSPACE); + + // Same PK, different num, different vectors + execute("INSERT INTO %s (pk, num, val) VALUES (0, 1, [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, num, val) VALUES (0, 2, [2.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, num, val) VALUES (0, 3, [3.0, 2.0, 3.0])"); + // Need PKs that wrap 0 when put in PK order + execute("INSERT INTO %s (pk, num, val) VALUES (1, 1, [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, num, val) VALUES (2, 1, [1.0, 2.0, 3.0])"); + + // the shadowed vector has the highest score, but we shouldn't see it + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s WHERE num < 3 ORDER BY val ann of [1.0, 2.0, 3.0] LIMIT 10"), + row(1), row(2)); + }); + } + + @Test + public void rangeRestrictedTestWithDuplicateVectorsAndADelete() + { + createTable(String.format("CREATE TABLE %%s (pk int, str_val text, val vector&lt;float, %d&gt;, PRIMARY KEY(pk))", 2)); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, val) VALUES (0, [1.0, 2.0])"); // -3485513579396041028 + execute("INSERT INTO %s (pk, val) VALUES (1, [1.0, 2.0])"); // -4069959284402364209 + execute("INSERT INTO %s (pk, val) VALUES (2, [1.0, 2.0])"); // -3248873570005575792 + execute("INSERT INTO %s (pk, val) VALUES (3, [1.0, 2.0])"); // 9010454139840013625 + + flush(); + + // Show the result set is as expected + assertRows(execute("SELECT pk FROM %s WHERE token(pk) <= -3248873570005575792 AND " + + "token(pk) >= -3485513579396041028 ORDER BY val ann of [1,2] LIMIT 1000"), row(0), row(2)); + + // Delete one of the rows + execute("DELETE FROM %s WHERE pk = 0"); + + flush(); + assertRows(execute("SELECT pk FROM %s WHERE token(pk) <= -3248873570005575792 AND " + + "token(pk) >= -3485513579396041028 ORDER BY val ann of [1,2] LIMIT 1000"), row(2)); + } + + @Test + public void rangeRestrictedTestWithDuplicateVectorsAndAddNullVector() throws Throwable + { + createTable(String.format("CREATE TABLE %%s (pk int, str_val text, val vector&lt;float, %d&gt;, PRIMARY KEY(pk))", 2)); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + + execute("INSERT INTO %s (pk, val) VALUES (0, [1.0, 2.0])"); + execute("INSERT INTO %s (pk, val) VALUES (1, [1.0, 2.0])"); + execute("INSERT INTO %s (pk, val) VALUES (2, [1.0, 2.0])"); + // Add a str_val to make sure pk has a row id in the sstable + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'a', null)"); + // Add another row to test a different part of the code + execute("INSERT INTO %s (pk, val) VALUES (4, [1.0, 2.0])"); + execute("DELETE FROM %s WHERE pk = 2"); + flush(); + + // Delete one of the rows to trigger a shadowed primary key + execute("DELETE FROM %s WHERE pk = 0"); + execute("INSERT INTO %s (pk, val) VALUES (2, [2.0, 2.0])"); + flush(); + + // Delete more rows. 
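+ // After these deletes, neither pk 2 (re-inserted and then deleted again) nor pk 3 (null vector) is live, so only pk 1 and pk 4 can match the query below.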
+ execute("DELETE FROM %s WHERE pk = 2"); + execute("DELETE FROM %s WHERE pk = 3"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1,2] LIMIT 1000"), + row(1), row(4)); + }); + } + + // This test intentionally has extra rows with primary keys that are above and below the + // deleted primary key so that we do not short circuit certain parts of the shadowed key logic. + @Test + public void shadowedPrimaryKeyInDifferentSSTableEachWithMultipleRows() + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, str_val text, val vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + disableCompaction(KEYSPACE); + + // flush a sstable with one vector + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'A', [1.0, 2.0, 3.0])"); + flush(); + + // flush another sstable to shadow the vector row + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'A', [1.0, 2.0, 3.0])"); + execute("DELETE FROM %s where pk = 2"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'A', [1.0, 2.0, 3.0])"); + flush(); + + // flush another sstable with one new vector row + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'B', [2.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (4, 'B', [2.0, 3.0, 4.0])"); + flush(); + + // the shadow vector has the highest score + UntypedResultSet result = execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0, 3.0] LIMIT 4"); + assertRows(result, row(1), row(3), row(0), row(4)); + } + + @Test + public void shadowedPrimaryKeysRequireDeeperSearch() throws Throwable + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, str_val text, val vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); createIndex("CREATE CUSTOM INDEX ON %s(str_val) USING 'StorageAttachedIndex'"); + disableCompaction(KEYSPACE); - // Create many sstables to ensure chunk size matters - // Start at 1 to prevent indexing zero vector. - // Index every vector with A to match everything and because this test only makes sense for hybrid queries - for (int i = 1; i <= 100; i++) + // Choose a row count that will essentially force us to re-query the index that still has more rows to search. + int baseRowCount = 1000; + // Create 1000 rows so that each row has a slightly less similar score. + for (int i = 0; i < baseRowCount - 10; i++) { - execute("INSERT INTO %s (pk, str_val, vec) VALUES (?, ?, ?)", i, "A", vector((float) i, (float) i)); - if (i % 10 == 0) - flush(); - // Add some deletes in the next segment - if (i % 3 == 0) - execute("DELETE FROM %s WHERE pk = ?", i); + try + { + execute("INSERT INTO %s (pk, str_val, val) VALUES (?, 'A', ?)", i, vector(1f, (float) i)); + } + catch (Error e) + { + logger.error("Failed to insert row {}", i, e); + throw new RuntimeException(e); + } } - try + for (int i = baseRowCount -10; i < baseRowCount; i++) + execute("INSERT INTO %s (pk, str_val, val) VALUES (?, 'A', ?)", i, vector(1f, (float) -i)); + + flush(); + + // Create 10 rows with the worst scores, but they won't be shadowed. 
+ for (int i = baseRowCount; i < baseRowCount + 10; i++) + execute("INSERT INTO %s (pk, str_val, val) VALUES (?, 'A', ?)", i, vector(-1f, (float) baseRowCount * -1)); + + // Delete all but the last 10 rows + for (int i = 0; i < baseRowCount - 10; i++) + execute("DELETE FROM %s WHERE pk = ?", i); + + beforeAndAfterFlush(() -> { + // ANN Only + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 1.0] LIMIT 3"), + row(baseRowCount - 10), row(baseRowCount - 9), row(baseRowCount - 8)); + // Hybrid + assertRows(execute("SELECT pk FROM %s WHERE str_val = 'A' ORDER BY val ann of [1.0, 1.0] LIMIT 3"), + row(baseRowCount - 10), row(baseRowCount - 9), row(baseRowCount - 8)); + }); + } + + @Test + public void testUpdateVectorToWorseAndBetterPositions() throws Throwable + { + createTable("CREATE TABLE %s (pk int, val vector&lt;float, 2&gt;, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, val) VALUES (0, [1.0, 2.0])"); + execute("INSERT INTO %s (pk, val) VALUES (1, [1.0, 3.0])"); + + flush(); + execute("INSERT INTO %s (pk, val) VALUES (0, [1.0, 4.0])"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0] LIMIT 1"), row(1)); + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0] LIMIT 2"), row(1), row(0)); + }); + + // And now update pk 1 to show that we can get 0 too + execute("INSERT INTO %s (pk, val) VALUES (1, [1.0, 5.0])"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0] LIMIT 1"), row(0)); + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0] LIMIT 2"), row(0), row(1)); + }); + + // And now update both PKs so that the stream of ranked rows is PKs: 0, 1, [1], 0, 1, [0], where the numbers + // wrapped in brackets are the "real" scores of the vectors. This test makes sure that we correctly remove + // PrimaryKeys from the updatedKeys map so that we don't accidentally duplicate PKs. + execute("INSERT INTO %s (pk, val) VALUES (1, [1.0, 3.5])"); + execute("INSERT INTO %s (pk, val) VALUES (0, [1.0, 6.0])"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0] LIMIT 1"), row(1)); + assertRows(execute("SELECT pk FROM %s ORDER BY val ann of [1.0, 2.0] LIMIT 2"), row(1), row(0)); + }); + } + + @Test + public void updatedPrimaryKeysRequireResumeSearch() throws Throwable + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, str_val text, val vector&lt;float, 2&gt;)"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(str_val) USING 'StorageAttachedIndex'"); + disableCompaction(KEYSPACE); + + // This test is fairly contrived, but it covers a bug we hit due to prematurely closed iterators. + // The general design for this test is to shadow the closest vectors on a memtable/sstable index forcing the + // index to resume search. We do that by overwriting the first 50 vectors in the initial sstable. 
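+ // The overwrites land in a newer memtable/sstable, so the older index's top candidates are all stale and the searcher must resume its traversal rather than stop after the first batch of results.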
+ for (int i = 0; i < 100; i++) + execute("INSERT INTO %s (pk, str_val, val) VALUES (?, 'A', ?)", i, vector(1f, (float) i)); + + // Add more rows to make sure we filter then sort + for (int i = 100; i < 1000; i++) + execute("INSERT INTO %s (pk, str_val, val) VALUES (?, 'C', ?)", i, vector(1f, (float) i)); + + flush(); + + // Overwrite the most similar 50 rows + for (int i = 0; i < 50; i++) + execute("INSERT INTO %s (pk, str_val, val) VALUES (?, 'B', ?)", i, vector(1f, (float) i)); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s WHERE str_val = 'A' ORDER BY val ann of [1.0, 1.0] LIMIT 1"), + row(50)); + }); + } + + @Test + public void testBruteForceRangeQueryWithUpdatedVectors1536D() throws Throwable + { + testBruteForceRangeQueryWithUpdatedVectors(1536); + } + + @Test + public void testBruteForceRangeQueryWithUpdatedVectors2D() throws Throwable + { + testBruteForceRangeQueryWithUpdatedVectors(2); + } + + private void testBruteForceRangeQueryWithUpdatedVectors(int vectorDimension) throws Throwable + { + createTable(String.format("CREATE TABLE %%s (pk int, val vector&lt;float, %d&gt;, PRIMARY KEY(pk))", vectorDimension)); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + // Insert 100 vectors + for (int i = 0; i < 100; i++) + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", i, randomVectorBoxed(vectorDimension)); + + // Update those vectors so some ordinals are changed + for (int i = 0; i < 100; i++) + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", i, randomVectorBoxed(vectorDimension)); + + // Delete the first 50 PKs. + for (int i = 0; i < 50; i++) + execute("DELETE FROM %s WHERE pk = ?", i); + + // All of the above inserts and deletes are performed on the same index to verify internal index behavior + // for both memtables and sstables. + beforeAndAfterFlush(() -> { + // Query for the first 10 vectors; we don't care which. + // Use a range query to hit the right brute force code path + UntypedResultSet results = execute("SELECT pk FROM %s WHERE token(pk) < 0 ORDER BY val ann of ? LIMIT 10", + randomVectorBoxed(vectorDimension)); + assertThat(results).hasSize(10); + // Make sure we don't get any of the deleted PKs + assertThat(results).allSatisfy(row -> assertThat(row.getInt("pk")).isGreaterThanOrEqualTo(50)); + }); + } + + @Test + public void testVectorIndexWithAllOrdinalsDeletedAndSomeViaRangeDeletion() + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int, a int, str_val text, val vector&lt;float, 3&gt;, PRIMARY KEY(pk, a))"); + createIndex("CREATE CUSTOM INDEX ON %s(str_val) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + disableCompaction(KEYSPACE); + + // Insert two rows with different vectors to get different ordinals + execute("INSERT INTO %s (pk, a, str_val, val) VALUES (1, 1, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, a, str_val, val) VALUES (2, 1, 'A', [1.0, 2.0, 4.0])"); + + // Range delete the first row + execute("DELETE FROM %s WHERE pk = 1"); + // Specifically delete the vector column second to hit a different code path. 
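+ // Between this delete and the vectorless re-insert below, row (2, 1) ends up with no live vector, so neither graph ordinal survives to the flush.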
+ execute("DELETE FROM %s WHERE pk = 2 AND a = 1"); + + // Insert another row without a vector + execute("INSERT INTO %s (pk, a, str_val) VALUES (2, 1, 'A')"); + flush(); + + assertRows(execute("SELECT PK FROM %s WHERE str_val = 'A' ORDER BY val ann of [1.0, 2.0, 3.0] LIMIT 1")); + } + + @Test + public void ensureCompressedVectorsCanFlush() + { + createTable("CREATE TABLE %s (pk int, val vector, PRIMARY KEY(pk))"); + IndexIdentifier indexIdentifier = createIndexIdentifier(createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'")); + + // insert enough vectors for pq plus 1 because we need quantization and we're deleting a row + for (int i = 0; i < MIN_PQ_ROWS + 1; i++) + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", i, vector(randomVector(4))); + + // Delete a single vector to trigger the regression + execute("DELETE from %s WHERE pk = 0"); + + flush(); + + verifySSTableIndexes(indexIdentifier, 1); + } + + // This test mimics having rf > 1. + @Test + public void testSameRowInMultipleSSTablesWithSameTimestamp() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck int, val vector, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + // We don't want compaction preventing us from hitting the intended code path. + disableCompaction(); + + // This test is fairly contrived, but covers the case where the first row we attempt to materialize in the + // ScoreOrderedResultRetriever is shadowed by a row in a different sstable. And then, when we go to pull in + // the next row, we find that the PK is already pulled in, so we need to skip it. + execute("INSERT INTO %s (pk, ck, val) VALUES (0, 0, [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, ck, val) VALUES (0, 1, [1.0, 2.0, 3.0]) USING TIMESTAMP 1"); + flush(); + // Now, delete row pk=0, ck=0 so that we can test that the shadowed row is not returned and that we need + // to get the next row from the score ordered iterator. + execute("DELETE FROM %s WHERE pk = 0 AND ck = 0"); + execute("INSERT INTO %s (pk, ck, val) VALUES (0, 1, [1.0, 2.0, 3.0]) USING TIMESTAMP 1"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT ck FROM %s ORDER BY val ANN OF [1.0, 2.0, 3.0] LIMIT 2"), row(1)); + }); + } + + @Test + public void testMemtableInsertSearchUpdateSearchHandling() + { + createTable("CREATE TABLE %s (id text PRIMARY KEY, embedding vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(embedding) USING 'StorageAttachedIndex' " + + "WITH OPTIONS = {'similarity_function': 'dot_product'}"); + + // Insert initial data + execute("INSERT INTO %s (id, embedding) VALUES ('row1', [0.1, 0.1, 0.1, 0.1, 0.1])"); + execute("INSERT INTO %s (id, embedding) VALUES ('row2', [0.9, 0.9, 0.9, 0.9, 0.9])"); + + // Query 100 times to try to guarantee all graph searchers are initialized + for (int i = 0; i < 100; i++) { - // We use a chunk size that is as low as possible (1) and goes up to the whole dataset (100). 
- // We also query for different LIMITs - for (int i = 1; i <= 100; i++) - { - SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE.setInt(i); - var results = execute("SELECT pk FROM %s WHERE str_val = 'A' ORDER BY vec ANN OF [1,1] LIMIT 1"); - assertRows(results, row(1)); - results = execute("SELECT pk FROM %s WHERE str_val = 'A' ORDER BY vec ANN OF [1,1] LIMIT 3"); - // Note that we delete row 3 - assertRows(results, row(1), row(2), row(4)); - results = execute("SELECT pk FROM %s WHERE str_val = 'A' ORDER BY vec ANN OF [1,1] LIMIT 10"); - // Note that we delete row 3, 6, 9, 12 - assertRows(results, row(1), row(2), row(4), row(5), - row(7), row(8), row(10), row(11), row(13), row(14)); - } + // Initial vector search + UntypedResultSet initialSearch = execute("SELECT * FROM %s ORDER BY embedding ANN OF [0.8, 0.8, 0.8, 0.8, 0.8] LIMIT 1"); + assertThat(initialSearch).hasSize(1); } - finally + + // Update one of the rows (previously, state leaked between queries meant this update went unobserved) + execute("UPDATE %s SET embedding = [0.7, 0.7, 0.7, 0.7, 0.7] WHERE id = 'row1'"); + + // Query 100 times to make sure it works as expected + for (int j = 0; j < 100; j++) { - // Revert to prevent interference with other tests. Note that a decreased chunk size can impact - // whether we compute the topk with brute force because it determines how many vectors get sent to the - // vector index. - SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE.setInt(100000); + // Get all data to verify we have 2 rows + UntypedResultSet allData = execute("SELECT * FROM %s ORDER BY embedding ANN OF [0.8, 0.8, 0.8, 0.8, 0.8] LIMIT 1000"); + assertThat(allData).hasSize(2); } } + + @Test + public void testUpdatedVectorStaticVectorColumnIndex() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck int, val vector&lt;float, 2&gt; static, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + // This counts as an update because the indexed column is static and is therefore operated on at the partition + // level. 
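+ // The second insert into partition 0 overwrites that partition's single static vector, so both of its rows are scored against [1,0] by the query below.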
+ execute("INSERT INTO %s (pk, ck, val) VALUES (0, 1, [0,2])"); + execute("INSERT INTO %s (pk, ck, val) VALUES (0, 2, [1,0])"); + execute("INSERT INTO %s (pk, ck, val) VALUES (1, 3, [0,1])"); + + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT ck FROM %s ORDER BY val ANN OF [1,0] LIMIT 2"), row(1), row(2)); + }); + } } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java index 06309e42a5c8..d7928aaeb5c8 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java @@ -45,6 +45,7 @@ import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.SAIRandomizedTester; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -65,6 +66,12 @@ public class InvertedIndexSearcherTest extends SAIRandomizedTester { private final PrimaryKey.Factory primaryKeyFactory = new PrimaryKey.Factory(Murmur3Partitioner.instance, new ClusteringComparator()); + @Override + public SSTableId getSSTableId() + { + return null; + } + @Override public PrimaryKey primaryKeyFromRowId(long sstableRowId) { @@ -194,6 +201,7 @@ private IndexSegmentSearcher buildIndexAndOpenSearcher(StorageAttachedIndex inde try (PerColumnIndexFiles indexFiles = new PerColumnIndexFiles(indexDescriptor, index.termType(), index.identifier())) { final IndexSegmentSearcher searcher = IndexSegmentSearcher.open(TEST_PRIMARY_KEY_MAP_FACTORY, + null, indexFiles, segmentMetadata, index); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeIndexBuilder.java index 4f843df3fb84..4947caf2891b 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeIndexBuilder.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/bbtree/BlockBalancedTreeIndexBuilder.java @@ -53,6 +53,7 @@ import org.apache.cassandra.index.sai.utils.IndexTermType; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.TermsIterator; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.AbstractGuavaIterator; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -69,6 +70,12 @@ public class BlockBalancedTreeIndexBuilder { private final PrimaryKey.Factory primaryKeyFactory = new PrimaryKey.Factory(Murmur3Partitioner.instance, null); + @Override + public SSTableId getSSTableId() + { + return null; + } + @Override public PrimaryKey primaryKeyFromRowId(long sstableRowId) { @@ -141,7 +148,7 @@ NumericIndexSegmentSearcher flushAndOpen(AbstractType type) throws IOExceptio try (PerColumnIndexFiles indexFiles = new PerColumnIndexFiles(indexDescriptor, index.termType(), index.identifier())) { - IndexSegmentSearcher searcher = IndexSegmentSearcher.open(TEST_PRIMARY_KEY_MAP_FACTORY, indexFiles, metadata, index); + IndexSegmentSearcher searcher = IndexSegmentSearcher.open(TEST_PRIMARY_KEY_MAP_FACTORY, null, indexFiles, metadata, index); assertThat(searcher, is(instanceOf(NumericIndexSegmentSearcher.class))); return 
(NumericIndexSegmentSearcher) searcher; } diff --git a/test/unit/org/apache/cassandra/index/sai/functional/FlushingTest.java b/test/unit/org/apache/cassandra/index/sai/functional/FlushingTest.java index d07c774cfee4..e3c462221cde 100644 --- a/test/unit/org/apache/cassandra/index/sai/functional/FlushingTest.java +++ b/test/unit/org/apache/cassandra/index/sai/functional/FlushingTest.java @@ -70,13 +70,13 @@ public void testFlushingOverwriteDelete() ResultSet rows = executeNet("SELECT id1 FROM %s WHERE v1 >= 0"); assertEquals(0, rows.all().size()); verifyIndexFiles(indexTermType, indexIdentifier, sstables, 0, sstables); - verifySSTableIndexes(indexIdentifier, sstables, 0); + verifySSTableIndexes(indexIdentifier, sstables, 0, 3); compact(); waitForAssert(() -> verifyIndexFiles(indexTermType, indexIdentifier, 1, 0, 1)); rows = executeNet("SELECT id1 FROM %s WHERE v1 >= 0"); assertEquals(0, rows.all().size()); - verifySSTableIndexes(indexIdentifier, 1, 0); + verifySSTableIndexes(indexIdentifier, 1, 0, 1); } } diff --git a/test/unit/org/apache/cassandra/index/sai/functional/GroupComponentsTest.java b/test/unit/org/apache/cassandra/index/sai/functional/GroupComponentsTest.java index d4ac27180ebe..851808875f14 100644 --- a/test/unit/org/apache/cassandra/index/sai/functional/GroupComponentsTest.java +++ b/test/unit/org/apache/cassandra/index/sai/functional/GroupComponentsTest.java @@ -32,6 +32,7 @@ import org.apache.cassandra.index.sai.SAITester; import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.StorageAttachedIndexGroup; +import org.apache.cassandra.index.sai.disk.EmptyIndex; import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.utils.IndexTermType; import org.apache.cassandra.io.sstable.Component; @@ -62,7 +63,7 @@ public void testInvalidateWithoutObsolete() // index files are released but not removed cfs.invalidate(true, false); - Assert.assertTrue(index.view().getIndexes().isEmpty()); + Assert.assertTrue(index.view().getIndexes().stream().allMatch(i -> i instanceof EmptyIndex)); for (Component component : components) Assert.assertTrue(sstable.descriptor.fileFor(component).exists()); } diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemoryIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemoryIndexTest.java index e1be7b276eff..de64a9737c2d 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemoryIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemoryIndexTest.java @@ -29,8 +29,6 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import com.google.common.collect.Sets; - import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; @@ -51,6 +49,7 @@ import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.marshal.Int32Type; import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.dht.BootStrapper; import org.apache.cassandra.dht.Bounds; @@ -62,19 +61,18 @@ import org.apache.cassandra.index.sai.SAITester; import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.plan.Expression; -import 
org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.index.sai.utils.RangeUtil; import org.apache.cassandra.inject.Injections; import org.apache.cassandra.inject.InvokePointBuilder; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.tcm.ClusterMetadata; +import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.FBUtilities; +import org.mockito.Mockito; import static org.apache.cassandra.config.CassandraRelevantProperties.MEMTABLE_SHARD_COUNT; import static org.apache.cassandra.config.CassandraRelevantProperties.ORG_APACHE_CASSANDRA_DISABLE_MBEAN_REGISTRATION; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -134,7 +132,9 @@ public static void reassignLocalTokens() @Test public void randomQueryTest() throws Exception { - memtableIndex = new VectorMemoryIndex(index); + // A non-null memtable tells the index to track the mapping from primary key to vector, needed for brute force search + Memtable memtable = Mockito.mock(Memtable.class); + memtableIndex = new VectorMemoryIndex(index, memtable); for (int row = 0; row < getRandom().nextIntBetween(1000, 5000); row++) { @@ -147,6 +147,9 @@ } List keys = new ArrayList<>(keyMap.keySet()); + long actualVectorsReturned = 0; + long expectedVectorsReturned = 0; + double expectedRecall = 0.9; for (int executionCount = 0; executionCount < 1000; executionCount++) { @@ -166,28 +169,41 @@ DataLimits.cqlLimits(limit), DataRange.allData(cfs.metadata().partitioner)); - try (KeyRangeIterator iterator = memtableIndex.search(new QueryContext(command, - DatabaseDescriptor.getRangeRpcTimeout(TimeUnit.MILLISECONDS)), - expression, keyRange)) + long expectedResults = Math.min(limit, keysInRange.size()); + + try (CloseableIterator&lt;PrimaryKeyWithScore&gt; iterator = memtableIndex.orderBy(new QueryContext(command, + DatabaseDescriptor.getRangeRpcTimeout(TimeUnit.MILLISECONDS)), + expression, keyRange)) { - while (iterator.hasNext()) + PrimaryKeyWithScore lastKey = null; + while (iterator.hasNext() && foundKeys.size() < expectedResults) { - PrimaryKey primaryKey = iterator.next(); - int key = Int32Type.instance.compose(primaryKey.partitionKey().getKey()); + PrimaryKeyWithScore primaryKeyWithScore = iterator.next(); + if (lastKey != null) + // This assertion only holds true as long as we query at most the expectedResults. + // Once we query deeper, we might get a key with a higher score than the last key. + // This is a direct consequence of the approximate part of ANN. + // Note that PrimaryKeyWithScore is flipped to descending order, so we use >= here. + assertTrue("Returned keys are not ordered by score", primaryKeyWithScore.compareTo(lastKey) >= 0); + lastKey = primaryKeyWithScore; + int key = Int32Type.instance.compose(primaryKeyWithScore.primaryKey().partitionKey().getKey()); assertFalse(foundKeys.contains(key)); - assertTrue(keyRange.contains(primaryKey.partitionKey())); + assertTrue(keyRange.contains(primaryKeyWithScore.primaryKey().partitionKey())); assertTrue(rowMap.containsKey(key)); foundKeys.add(key); } + // Note that we weight each result evenly instead of each query evenly. 
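+ // i.e. the 90% recall floor is enforced on the aggregate counts across all queries (plus the per-query floor below), rather than averaging per-query recall.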
+ actualVectorsReturned += foundKeys.size(); + expectedVectorsReturned += expectedResults; + if (foundKeys.size() < expectedResults) + assertTrue("Expected at least " + expectedResults + " results but got " + foundKeys.size(), + foundKeys.size() >= expectedResults * expectedRecall); } - // with -Dcassandra.test.random.seed=260652334768666, there is one missing key - long expectedResult = Math.min(limit, keysInRange.size()); - if (RangeUtil.coversFullRing(keyRange)) - assertEquals("Missing key: " + Sets.difference(keysInRange, foundKeys), expectedResult, foundKeys.size()); - else // if skip ANN, returned keys maybe larger than limit - assertTrue("Missing key: " + Sets.difference(keysInRange, foundKeys), expectedResult <= foundKeys.size()); } + + assertTrue("Expected at least " + expectedVectorsReturned + " results but got " + actualVectorsReturned, + actualVectorsReturned >= expectedVectorsReturned * expectedRecall); } @Test From fbb96e149f1f9f79174586d0602f5086441146fe Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Fri, 13 Feb 2026 15:14:56 -0600 Subject: [PATCH 2/2] Address interface changes to Cell and Row --- .../index/sai/utils/CellWithSource.java | 43 +++++++++++++++---- .../index/sai/utils/RowWithSource.java | 10 ++++- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java b/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java index 73316f85e69c..e1d31153da5d 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java +++ b/src/java/org/apache/cassandra/index/sai/utils/CellWithSource.java @@ -20,6 +20,10 @@ import java.nio.ByteBuffer; +import javax.annotation.Nonnull; + +import com.google.common.base.Function; + import org.apache.cassandra.db.DeletionPurger; import org.apache.cassandra.db.Digest; import org.apache.cassandra.db.marshal.ValueAccessor; @@ -134,6 +138,12 @@ public Cell withUpdatedValue(ByteBuffer newValue) return wrapIfNew(cell.withUpdatedValue(newValue)); } + @Override + public Cell withUpdatedTimestamp(long newTimestamp) + { + return wrapIfNew(cell.withUpdatedTimestamp(newTimestamp)); + } + @Override public Cell withUpdatedTimestampAndLocalDeletionTime(long newTimestamp, long newLocalDeletionTime) { @@ -191,14 +201,19 @@ public void digest(Digest digest) @Override public ColumnData updateAllTimestamp(long newTimestamp) { - ColumnData maybeNewCell = cell.updateAllTimestamp(newTimestamp); - if (maybeNewCell instanceof Cell) - return wrapIfNew((Cell) maybeNewCell); - if (maybeNewCell instanceof ComplexColumnData) - return ((ComplexColumnData) maybeNewCell).transform(this::wrapIfNew); - // It's not clear when we would hit this code path, but it seems we should not - // hit this from SAI. 
- throw new IllegalStateException("Expected a Cell instance, but got " + maybeNewCell); + return wrapIfNew(cell.updateAllTimestamp(newTimestamp)); + } + + @Override + public ColumnData updateTimesAndPathsForAccord(@Nonnull Function&lt;Cell, CellPath&gt; cellToMaybeNewListPath, long newTimestamp, long newLocalDeletionTime) + { + return wrapIfNew(cell.updateTimesAndPathsForAccord(cellToMaybeNewListPath, newTimestamp, newLocalDeletionTime)); + } + + @Override + public ColumnData updateAllTimesWithNewCellPathForComplexColumnData(@Nonnull CellPath maybeNewPath, long newTimestamp, long newLocalDeletionTime) + { + return wrapIfNew(cell.updateAllTimesWithNewCellPathForComplexColumnData(maybeNewPath, newTimestamp, newLocalDeletionTime)); } @Override @@ -232,6 +247,18 @@ public long maxTimestamp() return cell.maxTimestamp(); } + private ColumnData wrapIfNew(ColumnData maybeNewColumnData) + { + if (maybeNewColumnData instanceof Cell) + return wrapIfNew((Cell) maybeNewColumnData); + if (maybeNewColumnData instanceof ComplexColumnData) + return ((ComplexColumnData) maybeNewColumnData).transform(this::wrapIfNew); + + // It's not clear when we would hit this code path, but it seems we should not + // hit this from SAI. + throw new IllegalStateException("Expected a Cell or ComplexColumnData instance, but got " + maybeNewColumnData); + } + private Cell wrapIfNew(Cell maybeNewCell) { if (maybeNewCell == null) diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java b/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java index d4cdc2267600..e2766b8cbb4b 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java +++ b/src/java/org/apache/cassandra/index/sai/utils/RowWithSource.java @@ -23,8 +23,10 @@ import java.util.Iterator; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Function; + import javax.annotation.Nonnull; + +import com.google.common.base.Function; import com.google.common.collect.Collections2; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; @@ -276,6 +278,12 @@ public Row updateAllTimestamp(long newTimestamp) return maybeWrapRow(row.updateAllTimestamp(newTimestamp)); } + @Override + public Row updateTimesAndPathsForAccord(@Nonnull Function&lt;Cell, CellPath&gt; cellToMaybeNewListPath, long newTimestamp, long newLocalDeletionTime) + { + return maybeWrapRow(row.updateTimesAndPathsForAccord(cellToMaybeNewListPath, newTimestamp, newLocalDeletionTime)); + } + @Override public Row withRowDeletion(DeletionTime deletion) {