diff --git a/docs/changelog/127223.yaml b/docs/changelog/127223.yaml new file mode 100644 index 0000000000000..cc405c9f906d6 --- /dev/null +++ b/docs/changelog/127223.yaml @@ -0,0 +1,5 @@ +pr: 127223 +summary: Wrap ES KNN queries with PatienceKNN query +area: Vector Search +type: feature +issues: [] diff --git a/docs/reference/elasticsearch/index-settings/index-modules.md b/docs/reference/elasticsearch/index-settings/index-modules.md index 4ab35b9d80a88..542c5bc0f811c 100644 --- a/docs/reference/elasticsearch/index-settings/index-modules.md +++ b/docs/reference/elasticsearch/index-settings/index-modules.md @@ -259,3 +259,6 @@ $$$index-esql-stored-fields-sequential-proportion$$$ `index.esql.stored_fields_sequential_proportion` : Tuning parameter for deciding when {{esql}} will load [Stored fields](/reference/elasticsearch/rest-apis/retrieve-selected-fields.md#stored-fields) using a strategy tuned for loading dense sequence of documents. Allows values between 0.0 and 1.0 and defaults to 0.2. Indices with documents smaller than 10kb may see speed improvements loading `text` fields by setting this lower. + +$$$index-dense-vector-hnsw-early-termination$$$ `index.dense_vector.hnsw_early_termination` +: Whether to apply _patience_ based early termination strategy to knn queries over HNSW graphs (see [paper](https://cs.uwaterloo.ca/~jimmylin/publications/Teofili_Lin_ECIR2025.pdf)). This is only applicable to `dense_vector` fields with `hnsw`, `int8_hnsw`, `int4_hnsw` and `bbq_hnsw` index types. Defaults to `false`. diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 037438069dade..75497c4fcc392 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -48,7 +48,8 @@ record CmdLineArgs( VectorSimilarityFunction vectorSpace, int quantizeBits, VectorEncoding vectorEncoding, - int dimensions + int dimensions, + boolean earlyTermination ) implements ToXContentObject { static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); @@ -71,6 +72,7 @@ record CmdLineArgs( static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); static CmdLineArgs fromXContent(XContentParser parser) throws IOException { Builder builder = PARSER.apply(parser, null); @@ -100,6 +102,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); + PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD); } @Override @@ -158,6 +161,7 @@ static class Builder { private int quantizeBits = 8; private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; private int dimensions; + private boolean earlyTermination; public Builder setDocVectors(String docVectors) { this.docVectors = PathUtils.get(docVectors); @@ -259,6 +263,11 @@ public Builder setDimensions(int dimensions) { return this; } + public Builder setEarlyTermination(Boolean patience) { + this.earlyTermination = patience; + return this; + } + public CmdLineArgs build() { if (docVectors == null) { throw new IllegalArgumentException("Document vectors path must be provided"); @@ -288,7 +297,8 @@ public CmdLineArgs build() { vectorSpace, quantizeBits, vectorEncoding, - dimensions + dimensions, + earlyTermination ); } } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 25525fe40f92c..1316b4ef8881b 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -211,7 +211,7 @@ public static void main(String[] args) throws Exception { for (int i = 0; i < results.length; i++) { int nProbe = nProbes[i]; KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe); - knnSearcher.runSearch(results[i]); + knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination()); } } formattedResults.results.addAll(List.of(results)); diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 938bbc0ef8456..7967797e797f9 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -33,6 +33,9 @@ import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -114,7 +117,7 @@ class KnnSearcher { this.searchThreads = cmdLineArgs.searchThreads(); } - void runSearch(KnnIndexTester.Results finalResults) throws IOException { + void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException { TopDocs[] results = new TopDocs[numQueryVectors]; int[][] resultIds = new int[numQueryVectors][]; long elapsed, totalCpuTimeMS, totalVisited = 0; @@ -153,10 +156,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException { for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { targetReader.next(targetBytes); - doVectorQuery(targetBytes, searcher); + doVectorQuery(targetBytes, searcher, earlyTermination); } else { targetReader.next(target); - doVectorQuery(target, searcher); + doVectorQuery(target, searcher, earlyTermination); } } targetReader.reset(); @@ -165,10 +168,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException { for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { targetReader.next(targetBytes); - results[i] = doVectorQuery(targetBytes, searcher); + results[i] = doVectorQuery(targetBytes, searcher, earlyTermination); } else { targetReader.next(target); - results[i] = doVectorQuery(target, searcher); + results[i] = doVectorQuery(target, searcher, earlyTermination); } } KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails(); @@ -264,7 +267,7 @@ private boolean isNewer(Path path, Path... others) throws IOException { return true; } - TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException { + TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException { Query knnQuery; if (overSamplingFactor > 1f) { throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors"); @@ -280,6 +283,9 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException null, DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() ); + if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) { + knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery); + } } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); @@ -288,7 +294,7 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); } - TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException { + TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException { Query knnQuery; int topK = this.topK; if (overSamplingFactor > 1f) { @@ -307,6 +313,9 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException null, DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() ); + if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) { + knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery); + } } if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result @@ -314,9 +323,12 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); - QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery; - queryProfilerProvider.profile(profiler); - return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); + if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(profiler); + return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); + } else { + return docs; + } } private static float checkResults(int[][] results, int[][] nn, int topK) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java index 82f63ebbbee12..3dabe1b37b43e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java @@ -127,4 +127,55 @@ public void testFilteredQueryStrategy() { }); } + public void testHnswEarlyTerminationQuery() { + float[] vector = new float[16]; + randomVector(vector, 25); + int upperLimit = 35; + var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null); + assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), response -> { + assertNotEquals(0, response.getHits().getHits().length); + var profileResults = response.getProfileResults(); + long vectorOpsSum = profileResults.values() + .stream() + .mapToLong( + pr -> pr.getQueryPhase() + .getSearchProfileDfsPhaseResult() + .getQueryProfileShardResult() + .stream() + .mapToLong(qpr -> qpr.getVectorOperationsCount().longValue()) + .sum() + ) + .sum(); + client().admin() + .indices() + .prepareUpdateSettings(INDEX_NAME) + .setSettings(Settings.builder().put(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION.getKey(), true)) + .get(); + assertResponse( + client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), + earlyTerminationResponse -> { + assertNotEquals(0, earlyTerminationResponse.getHits().getHits().length); + var earlyTerminationResults = earlyTerminationResponse.getProfileResults(); + long earlyTerminationVectorOpsSum = earlyTerminationResults.values() + .stream() + .mapToLong( + pr -> pr.getQueryPhase() + .getSearchProfileDfsPhaseResult() + .getQueryProfileShardResult() + .stream() + .mapToLong(qpr -> qpr.getVectorOperationsCount().longValue()) + .sum() + ) + .sum(); + assertTrue( + "earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]", + earlyTerminationVectorOpsSum < vectorOpsSum + // if both switch to brute-force due to excessive exploration, they will both equal to upperLimit + || (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1) + ); + } + ); + }); + } + } diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index 0ff64f14dc17c..796b03211432b 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -159,6 +159,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING, IndexSettings.INDEX_SEARCH_IDLE_AFTER, DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, + DenseVectorFieldMapper.HNSW_EARLY_TERMINATION, IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY, IndexSettings.IGNORE_ABOVE_SETTING, FieldMapper.IGNORE_MALFORMED_SETTING, diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index 4b89ab2e60021..eac2ef3d42b61 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -916,6 +916,7 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) { private volatile int maxNgramDiff; private volatile int maxShingleDiff; private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic; + private volatile boolean earlyTermination; private volatile TimeValue searchIdleAfter; private volatile int maxAnalyzedOffset; private volatile boolean weightMatchesEnabled; @@ -1113,6 +1114,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING); skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING); hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC); + earlyTermination = scopedSettings.get(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION); indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING); recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings); recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false @@ -1227,6 +1229,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti ); scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead); scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic); + scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION, this::setHnswEarlyTermination); } private void setSearchIdleAfter(TimeValue searchIdleAfter) { @@ -1858,6 +1861,14 @@ private void setHnswFilterHeuristic(DenseVectorFieldMapper.FilterHeuristic heuri this.hnswFilterHeuristic = heuristic; } + public boolean getHnswEarlyTermination() { + return this.earlyTermination; + } + + private void setHnswEarlyTermination(boolean earlyTermination) { + this.earlyTermination = earlyTermination; + } + public SeqNoFieldMapper.SeqNoIndexOptions seqNoIndexOptions() { return seqNoIndexOptions; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 819d9608f2348..b21724d0b1c61 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -33,6 +33,9 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.knn.KnnSearchStrategy; @@ -123,6 +126,7 @@ public class DenseVectorFieldMapper extends FieldMapper { private static final float EPS = 1e-3f; public static final int BBQ_MIN_DIMS = 64; + private static final boolean DEFAULT_HNSW_EARLY_TERMINATION = false; public static final FeatureFlag IVF_FORMAT = new FeatureFlag("ivf_format"); public static boolean isNotUnitVector(float magnitude) { @@ -174,6 +178,14 @@ public KnnSearchStrategy getKnnSearchStrategy() { Setting.Property.Dynamic ); + public static final Setting HNSW_EARLY_TERMINATION = Setting.boolSetting( + "index.dense_vector.hnsw_enable_early_termination", + DEFAULT_HNSW_EARLY_TERMINATION, + Setting.Property.IndexScope, + Setting.Property.ServerlessPublic, + Setting.Property.Dynamic + ); + private static boolean hasRescoreIndexVersion(IndexVersion version) { return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) || version.between(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0); @@ -212,7 +224,7 @@ private static boolean defaultOversampleForBBQ(IndexVersion version) { public static final int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions public static final short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to - // vector + // vector public static final int MAGNITUDE_BYTES = 4; public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed public static final float DEFAULT_OVERSAMPLE = 3.0F; // Default oversample value @@ -1429,6 +1441,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map createKnnFloatQuery( queryVector.asFloatVector(), @@ -2513,7 +2528,8 @@ public Query createKnnQuery( filter, similarityThreshold, parentFilter, - knnSearchStrategy + knnSearchStrategy, + hnswEarlyTermination ); case BIT -> createKnnBitQuery( queryVector.asByteVector(), @@ -2522,7 +2538,8 @@ public Query createKnnQuery( filter, similarityThreshold, parentFilter, - knnSearchStrategy + knnSearchStrategy, + hnswEarlyTermination ); }; } @@ -2542,7 +2559,8 @@ private Query createKnnBitQuery( Query filter, Float similarityThreshold, BitSetProducer parentFilter, - KnnSearchStrategy searchStrategy + KnnSearchStrategy searchStrategy, + boolean hnswEarlyTermination ) { elementType.checkDimensions(dims, queryVector.length); Query knnQuery; @@ -2559,6 +2577,9 @@ private Query createKnnBitQuery( knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + if (hnswEarlyTermination) { + knnQuery = maybeWrapPatience(knnQuery); + } } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( @@ -2577,7 +2598,8 @@ private Query createKnnByteQuery( Query filter, Float similarityThreshold, BitSetProducer parentFilter, - KnnSearchStrategy searchStrategy + KnnSearchStrategy searchStrategy, + boolean hnswEarlyTermination ) { elementType.checkDimensions(dims, queryVector.length); @@ -2585,7 +2607,6 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Query knnQuery; if (indexOptions != null && indexOptions.isFlat()) { var exactKnnQuery = parentFilter != null @@ -2600,6 +2621,9 @@ private Query createKnnByteQuery( knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + if (hnswEarlyTermination) { + knnQuery = maybeWrapPatience(knnQuery); + } } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( @@ -2611,6 +2635,23 @@ private Query createKnnByteQuery( return knnQuery; } + private Query maybeWrapPatience(Query knnQuery) { + Query finalQuery = knnQuery; + if (knnQuery instanceof KnnByteVectorQuery knnByteVectorQuery && canApplyPatienceQuery()) { + finalQuery = PatienceKnnVectorQuery.fromByteQuery(knnByteVectorQuery); + } else if (knnQuery instanceof KnnFloatVectorQuery knnFloatVectorQuery && canApplyPatienceQuery()) { + finalQuery = PatienceKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery); + } + return finalQuery; + } + + private boolean canApplyPatienceQuery() { + return indexOptions instanceof HnswIndexOptions + || indexOptions instanceof Int8HnswIndexOptions + || indexOptions instanceof Int4HnswIndexOptions + || indexOptions instanceof BBQHnswIndexOptions; + } + private Query createKnnFloatQuery( float[] queryVector, int k, @@ -2619,7 +2660,8 @@ private Query createKnnFloatQuery( Query filter, Float similarityThreshold, BitSetProducer parentFilter, - KnnSearchStrategy knnSearchStrategy + KnnSearchStrategy knnSearchStrategy, + boolean hnswEarlyTermination ) { elementType.checkDimensions(dims, queryVector.length); elementType.checkVectorBounds(queryVector); @@ -2686,6 +2728,9 @@ && isNotUnitVector(squaredMagnitude)) { knnSearchStrategy ) : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy); + if (hnswEarlyTermination) { + knnQuery = maybeWrapPatience(knnQuery); + } } if (rescore) { knnQuery = RescoreKnnVectorQuery.fromInnerQuery( @@ -2808,7 +2853,6 @@ public void parse(DocumentParserContext context) throws IOException { } if (fieldType().dims == null) { int dims = fieldType().elementType.parseDimensionCount(context); - ; final boolean defaultInt8Hnsw = indexCreatedVersion.onOrAfter(IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW); final boolean defaultBBQ8Hnsw = indexCreatedVersion.onOrAfter(IndexVersions.DEFAULT_DENSE_VECTOR_TO_BBQ_HNSW); DenseVectorIndexOptions denseVectorIndexOptions = fieldType().indexOptions; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 87f9a50c64c17..ea0c15642eb74 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -553,6 +553,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } } DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic(); + boolean hnswEarlyTermination = context.getIndexSettings().getHnswEarlyTermination(); return vectorFieldType.createKnnQuery( queryVector, k, @@ -561,7 +562,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { filterQuery, vectorSimilarity, parentBitSet, - heuristic + heuristic, + hnswEarlyTermination ); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 8f17dfa8fd56e..0f161d4a1e44f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -2429,7 +2429,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2447,7 +2448,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2465,7 +2467,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2483,7 +2486,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2501,7 +2505,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -2516,7 +2521,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2534,7 +2540,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2569,7 +2576,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -2584,7 +2592,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( @@ -2602,7 +2611,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index bcb4214a23f3d..2524422ed8f90 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -9,8 +9,8 @@ package org.elasticsearch.index.mapper.vectors; -import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; @@ -234,7 +234,8 @@ public void testCreateNestedKnnQuery() { null, null, producer, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); @@ -242,7 +243,7 @@ public void testCreateNestedKnnQuery() { if (field.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { - assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); + assertTrue(query instanceof DiversifyingChildrenFloatKnnVectorQuery || query instanceof PatienceKnnVectorQuery); } } { @@ -272,12 +273,13 @@ public void testCreateNestedKnnQuery() { null, null, producer, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (field.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { - assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery); } vectorData = new VectorData(floatQueryVector, null); @@ -289,12 +291,13 @@ public void testCreateNestedKnnQuery() { null, null, producer, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (field.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { - assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery); } } } @@ -366,7 +369,8 @@ public void testFloatCreateKnnQuery() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -396,7 +400,8 @@ public void testFloatCreateKnnQuery() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -422,7 +427,8 @@ public void testFloatCreateKnnQuery() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); @@ -453,7 +459,8 @@ public void testCreateKnnQueryMaxDims() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); @@ -461,7 +468,7 @@ public void testCreateKnnQueryMaxDims() { if (fieldWith4096dims.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); } else { - assertThat(query, instanceOf(KnnFloatVectorQuery.class)); + assertTrue(query instanceof KnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery); } } @@ -490,12 +497,13 @@ public void testCreateKnnQueryMaxDims() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (fieldWith4096dims.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Bytes.class)); } else { - assertThat(query, instanceOf(KnnByteVectorQuery.class)); + assertTrue(query instanceof ESKnnByteVectorQuery || query instanceof PatienceKnnVectorQuery); } } } @@ -522,7 +530,8 @@ public void testByteCreateKnnQuery() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -548,7 +557,8 @@ public void testByteCreateKnnQuery() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); @@ -563,7 +573,8 @@ public void testByteCreateKnnQuery() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); @@ -591,24 +602,33 @@ public void testRescoreOversampleUsedWithoutQuantization() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (elementType == BYTE) { if (nonQuantizedField.getIndexOptions().isFlat()) { assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class)); } else { - ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { + assertThat(patienceKnnVectorQuery.getK(), is(100)); + } else { + ESKnnByteVectorQuery knnByteVectorQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(knnByteVectorQuery.getK(), is(100)); + assertThat(knnByteVectorQuery.kParam(), is(10)); + } } } else { if (nonQuantizedField.getIndexOptions().isFlat()) { assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class)); } else { - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { + assertThat(patienceKnnVectorQuery.getK(), is(100)); + } else { + ESKnnFloatVectorQuery knnFloatVectorQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(knnFloatVectorQuery.getK(), is(100)); + assertThat(knnFloatVectorQuery.kParam(), is(10)); + } } } } @@ -655,12 +675,13 @@ public void testRescoreOversampleQueryOverrides() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); if (fieldType.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); } else { - assertThat(query, instanceOf(ESKnnFloatVectorQuery.class)); + assertTrue(query instanceof ESKnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery); } // verify we can override a `0` to a positive number @@ -683,20 +704,23 @@ public void testRescoreOversampleQueryOverrides() { null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); assertTrue(query instanceof RescoreKnnVectorQuery); - assertThat(((RescoreKnnVectorQuery) query).k(), equalTo(10)); - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) ((RescoreKnnVectorQuery) query).innerQuery(); - assertThat(esKnnQuery.kParam(), equalTo(20)); - + RescoreKnnVectorQuery rescoreKnnVectorQuery = (RescoreKnnVectorQuery) query; + assertThat(rescoreKnnVectorQuery.k(), equalTo(10)); + Query innerQuery = rescoreKnnVectorQuery.innerQuery(); + if (innerQuery instanceof ESKnnFloatVectorQuery esKnnFloatVectorQuery) { + assertThat(esKnnFloatVectorQuery.kParam(), equalTo(20)); + } } public void testFilterSearchThreshold() { List>> cases = List.of( - Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), - Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()), - Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy()) + Tuple.tuple(FLOAT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnFloatVectorQuery) q).getStrategy()), + Tuple.tuple(BYTE, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy()), + Tuple.tuple(BIT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy()) ); for (var tuple : cases) { DenseVectorFieldType fieldType = new DenseVectorFieldType( @@ -720,25 +744,31 @@ public void testFilterSearchThreshold() { null, null, null, - DenseVectorFieldMapper.FilterHeuristic.FANOUT + DenseVectorFieldMapper.FilterHeuristic.FANOUT, + randomBoolean() ); KnnSearchStrategy strategy = tuple.v2().apply(query); - assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); - assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0)); - - query = fieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 1, 4, 10 }), - 10, - 100, - 0f, - null, - null, - null, - DenseVectorFieldMapper.FilterHeuristic.ACORN - ); - strategy = tuple.v2().apply(query); - assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); - assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60)); + if (strategy != null) { + assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); + assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0)); + + query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + 10, + 100, + 0f, + null, + null, + null, + DenseVectorFieldMapper.FilterHeuristic.ACORN, + randomBoolean() + ); + strategy = tuple.v2().apply(query); + if (strategy != null) { + assertThat(strategy, instanceOf(KnnSearchStrategy.Hnsw.class)); + assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60)); + } + } } } @@ -759,12 +789,18 @@ private static void checkRescoreQueryParameters( null, null, null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()), + randomBoolean() ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); - assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); - assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); - assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); + Query innerQuery = rescoreQuery.innerQuery(); + if (innerQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { + assertThat("Unexpected candidates", patienceKnnVectorQuery.getK(), equalTo(expectedCandidates)); + } else { + ESKnnFloatVectorQuery knnQuery = (ESKnnFloatVectorQuery) innerQuery; + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates)); + assertThat("Unexpected k parameter", knnQuery.kParam(), equalTo(expectedK)); + } } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java index 320b3efc4924c..8ff81cda6e8a0 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java @@ -120,7 +120,8 @@ public void testRandom() throws IOException { null, null, bitSetproducer, - DenseVectorFieldMapper.FilterHeuristic.ACORN + DenseVectorFieldMapper.FilterHeuristic.ACORN, + randomBoolean() ); assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class)); var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetproducer, ScoreMode.Total); diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java index 18e125a7ae1ce..b0be3e21bbc7c 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java @@ -531,7 +531,8 @@ static String[] extractLeaderShardHistoryUUIDs(Map ccrIndexMetad DataTier.TIER_PREFERENCE_SETTING, IndexSettings.BLOOM_FILTER_ID_FIELD_ENABLED_SETTING, MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING, - DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC + DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, + DenseVectorFieldMapper.HNSW_EARLY_TERMINATION ); public static Settings filter(Settings originalSettings) {