From 7f313f471e974ef9e04775530baed80f138a58bc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 18 Jul 2025 18:42:19 +0200 Subject: [PATCH 01/24] Add exact NN query infra --- .../function/fulltext/FullTextFunction.java | 4 +- .../expression/function/vector/ExactNN.java | 224 ++++++++++++++++++ .../esql/querydsl/query/ExactNNQuery.java | 69 ++++++ 3 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index b5378db783f46..8f553423bd17f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -377,7 +377,7 @@ public static void fieldVerifier(LogicalPlan plan, FullTextFunction function, Ex } @Override - public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { List shardContexts = toEvaluator.shardContexts(); ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()]; int i = 0; @@ -388,7 +388,7 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua } @Override - public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { + public final ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { List shardContexts = toScorer.shardContexts(); ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()]; int i = 0; diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java new file mode 100644 index 0000000000000..70c8e4abad7f6 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java @@ -0,0 +1,224 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.Check; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Param; +import 
org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; +import org.elasticsearch.xpack.esql.querydsl.query.ExactNNQuery; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.function.BiConsumer; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; + +/** + * Exact nearest neighbour search using a dense_vector similarity function. Used to translate {@link Knn} into exact search + * when it can't be pushed down to Lucene. Not exposed to users directly. 
+ */ +public class ExactNN extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", ExactNN::readFrom); + + private final Expression field; + private final Expression minimumSimilarity; + + @FunctionInfo( + returnType = "boolean", + preview = true, + description = "Finds all nearest vectors to a query vector, as measured by a similarity metric. " + + "performs brute force search over all vectors in the index.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } + ) + public ExactNN( + Source source, + @Param(name = "field", type = { "dense_vector" }, description = "Field that the query will target.") Expression field, + @Param( + name = "query", + type = { "dense_vector" }, + description = "Vector value to find top nearest neighbours for." + ) Expression query, + @Param( + name = "similarity", + type = { "double" }, + optional = true, + description = "The minimum similarity required for a document to be considered a match. " + + "The similarity value calculated relates to the raw similarity used, not the document score." + ) + Expression minimumSimilarity + ) { + this(source, field, query, minimumSimilarity, null); + } + + public ExactNN( + Source source, + Expression field, + Expression query, + Expression minimumSimilarity, + QueryBuilder queryBuilder + ) { + super(source, query, minimumSimilarity == null ? 
List.of(field, query) : List.of(field, query, minimumSimilarity), queryBuilder); + this.field = field; + this.minimumSimilarity = minimumSimilarity; + } + + public Expression field() { + return field; + } + + public Expression minimumSimilarity() { + return minimumSimilarity; + } + + @Override + public DataType dataType() { + return DataType.BOOLEAN; + } + + @Override + protected TypeResolution resolveParams() { + return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector")); + } + + private TypeResolution resolveField() { + return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector")); + } + + private TypeResolution resolveQuery() { + return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and( + isNotNullAndFoldable(query(), sourceText(), SECOND) + ); + } + + private TypeResolution resolveMinimumSimilarity() { + if (minimumSimilarity == null) { + return TypeResolution.TYPE_RESOLVED; + } + + return isNotNull(minimumSimilarity(), sourceText(), THIRD) + .and(isNumeric(minimumSimilarity(), sourceText(), THIRD)); + } + + @Override + public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { + return new ExactNN(source(), field(), query(), minimumSimilarity(), queryBuilder); + } + + @Override + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { + var fieldAttribute = Match.fieldAsFieldAttribute(field()); + + Check.notNull(fieldAttribute, "Exact must have a field attribute as the first argument"); + String fieldName = getNameFromFieldAttribute(fieldAttribute); + @SuppressWarnings("unchecked") + List queryFolded = (List) query().fold(FoldContext.small() /* TODO remove me */); + float[] queryAsFloats = new float[queryFolded.size()]; + for (int i = 0; i < queryFolded.size(); i++) { + queryAsFloats[i] = 
queryFolded.get(i).floatValue(); + } + Float similarity = minimumSimilarity != null ? ((Number) minimumSimilarity().fold(FoldContext.small())).floatValue() : null; + + return new ExactNNQuery(source(), fieldName, queryAsFloats, similarity); + } + + @Override + public BiConsumer postAnalysisPlanVerification() { + return (plan, failures) -> { + super.postAnalysisPlanVerification().accept(plan, failures); + fieldVerifier(plan, this, field, failures); + }; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new ExactNN( + source(), + newChildren.get(0), + newChildren.get(1), + newChildren.size() > 2 ? newChildren.get(2) : null, + queryBuilder() + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, ExactNN::new, field(), query(), minimumSimilarity(), queryBuilder()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + private static ExactNN readFrom(StreamInput in) throws IOException { + Source source = Source.readFrom((PlanStreamInput) in); + Expression field = in.readNamedWriteable(Expression.class); + Expression query = in.readNamedWriteable(Expression.class); + Expression minimumSimilarity = in.readNamedWriteable(Expression.class); + QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); + return new ExactNN(source, field, query, minimumSimilarity, queryBuilder); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field()); + out.writeNamedWriteable(query()); + out.writeNamedWriteable(minimumSimilarity()); + out.writeOptionalNamedWriteable(queryBuilder()); + } + + @Override + public boolean equals(Object o) { + // Knn does not serialize options, as they get included in the query builder. 
We need to override equals and hashcode to + // ignore options when comparing two Knn functions + if (o == null || getClass() != o.getClass()) return false; + ExactNN knn = (ExactNN) o; + return Objects.equals(field(), knn.field()) + && Objects.equals(query(), knn.query()) + && Objects.equals(minimumSimilarity(), knn.minimumSimilarity()) + && Objects.equals(queryBuilder(), knn.queryBuilder()); + } + + @Override + public int hashCode() { + return Objects.hash(field(), query(), minimumSimilarity(), queryBuilder()); + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java new file mode 100644 index 0000000000000..9df73c45604fe --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.querydsl.query; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.Arrays; +import java.util.Objects; + +public class ExactNNQuery extends Query { + + private final String field; + private final float[] query; + private final Float minimumSimilarity; + + public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; + + public ExactNNQuery(Source source, String field, float[] query, Float minimumSimilarity) { + super(source); + this.field = field; + this.query = query; + this.minimumSimilarity = minimumSimilarity; + } + + @Override + protected QueryBuilder asBuilder() { + return new ExactKnnQueryBuilder(VectorData.fromFloats(query), field, minimumSimilarity); + } + + @Override + protected String innerToString() { + return "exactNN(" + field + ", " + Arrays.toString(query) + " minimumSimilarity=" + minimumSimilarity + ")"; + } + + @Override + public boolean equals(Object o) { + if (super.equals(o) == false) return false; + + if (o == null || getClass() != o.getClass()) return false; + ExactNNQuery knnQuery = (ExactNNQuery) o; + return Objects.equals(field, knnQuery.field) + && Objects.deepEquals(query, knnQuery.query) + && Objects.equals(minimumSimilarity, knnQuery.minimumSimilarity) + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), minimumSimilarity); + } + + @Override + public boolean scorable() { + return true; + } + + @Override + public boolean containsPlan() { + return false; + } +} From 2b1a4fa789ffda3700c89b22fc0964a0f7f5a907 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 21 Jul 2025 21:52:08 +0200 Subject: [PATCH 02/24] First version 
ReplaceKnnWithNoPushedDownFiltersWithEvalTopN --- .../xpack/esql/plugin/KnnFunctionIT.java | 52 +++++- .../expression/function/vector/ExactNN.java | 7 +- .../esql/expression/function/vector/Knn.java | 5 + .../optimizer/LocalLogicalPlanOptimizer.java | 4 +- ...nnWithNoPushedDownFiltersWithEvalTopN.java | 154 ++++++++++++++++++ .../esql/querydsl/query/ExactNNQuery.java | 8 +- 6 files changed, 220 insertions(+), 10 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 9ae1c980337f1..07b9825e9b909 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -32,6 +33,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.hamcrest.CoreMatchers.containsString; +@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug") public class KnnFunctionIT extends AbstractEsqlIntegTestCase { private final Map> indexedVectors = new HashMap<>(); @@ -157,6 +159,50 @@ public void testKnnWithLookupJoin() { ); } + public void testKnnNotPushedDown() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // 
We retrieve 5 from knn, and 5 from the non-pushed down disjunction. They are disjoint so we get 10 as a result + var query = String.format(Locale.ROOT, """ + FROM test + | WHERE knn(vector, %s, 5) OR (length(keyword) > 5 AND length(keyword) <= 10) + | KEEP id, vector, keyword + | SORT id ASC + | LIMIT 20 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "vector", "keyword")); + assertColumnTypes(resp.columns(), List.of("integer", "dense_vector", "keyword")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(10, valuesList.size()); + } + } + + public void testKnnPrefiltersNotPushedDown() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, but must be prefiltered with the non-pushed down conjunction + var query = String.format(Locale.ROOT, """ + FROM test + | WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10 + | KEEP id, vector, keyword + | SORT id ASC + | LIMIT 20 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "vector", "keyword")); + assertColumnTypes(resp.columns(), List.of("integer", "dense_vector", "keyword")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(5, valuesList.size()); + } + } + @Before public void setup() throws IOException { assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); @@ -176,6 +222,9 @@ public void setup() throws IOException { .startObject("floats") .field("type", "float") .endObject() + .startObject("keyword") + .field("type", "keyword") + .endObject() .endObject() .endObject(); @@ -195,7 +244,8 @@ public void setup() throws IOException { for (int j = 0; j < numDims; j++) { vector.add(value++); } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + docs[i] = 
prepareIndex("test").setId("" + i) + .setSource("id", String.valueOf(i), "floats", vector, "vector", vector, "keyword", randomAlphaOfLength(i)); indexedVectors.put(i, vector); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java index 70c8e4abad7f6..0935d3e969af3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java @@ -43,7 +43,6 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; @@ -113,7 +112,7 @@ public DataType dataType() { @Override protected TypeResolution resolveParams() { - return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector")); + return resolveField().and(resolveQuery()).and(resolveMinimumSimilarity()); } private TypeResolution resolveField() { @@ -121,8 +120,8 @@ private TypeResolution resolveField() { } private TypeResolution resolveQuery() { - return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and( - isNotNullAndFoldable(query(), sourceText(), SECOND) + return isNotNull(query(), sourceText(), SECOND).and( + isType(query(), 
dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector") ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index cab5ec862d7f5..78f63d7be4b67 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; @@ -274,6 +275,10 @@ public Expression withFilters(List filterExpressions) { return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } + public boolean hasNonPushableFilters() { + return filterExpressions().size() > ((KnnVectorQueryBuilder) queryBuilder()).filterQueries().size(); + } + private Map queryOptions() throws InvalidArgumentException { Map options = new HashMap<>(); if (options() != null) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index 39f37f952ae02..ef33f8529b7bc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -17,6 +17,7 @@ import 
org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalPropagateEmptyRelation; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalSubstituteSurrogateExpressions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceFieldWithConstantOrNull; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceKnnWithNoPushedDownFiltersWithEvalTopN; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNWithLimitAndSort; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor; @@ -48,7 +49,8 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor { + + public static final String EXACT_SCORE_ATTR_NAME = "knn_score"; + + public ReplaceKnnWithNoPushedDownFiltersWithEvalTopN() { + super(UP); + } + + @Override + protected LogicalPlan rule(Filter filter) { + Expression condition = filter.condition(); + + Holder> replaced = new Holder<>(new ArrayList<>()); + Expression conditionWithoutKnns = condition.transformDown(Knn.class, knn -> replaceNonPushableKnnByTrue(knn, replaced)); + if (conditionWithoutKnns.equals(condition)) { + return filter; + } + + // Replace knn with scoring expressions of exact queries + List exactQueries = replaced.get() + .stream() + .map(ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExact) + .toList(); + int numExactQueries = replaced.get().size(); + assert numExactQueries > 0; + List scoringAliases = new ArrayList<>(numExactQueries); + AttributeMap.Builder aliasesBuilder = AttributeMap.builder(); + for (int i = 0; i < numExactQueries; i++) { + String name = rawTemporaryName(EXACT_SCORE_ATTR_NAME, String.valueOf(i)); + Alias alias = new Alias(EMPTY, name, new Score(EMPTY, exactQueries.get(i))); + scoringAliases.add(alias); + aliasesBuilder.put(alias.toAttribute(), alias.child()); + } + + Eval scoreEval = new Eval( + EMPTY, + filter.with(conditionWithoutKnns), + 
scoringAliases + ); + + // Filter for all exact scores > 0 + Expression scoreComparison = null; + List scoringAttributes = new ArrayList<>(numExactQueries); + for (int i = 0; i < numExactQueries; i++) { + Attribute scoringAttr = scoringAliases.get(i).toAttribute(); + scoringAttributes.add(scoringAttr); + GreaterThan gt = new GreaterThan( + EMPTY, + scoringAttr, + new Literal(EMPTY, 0.0, DataType.DOUBLE) + ); + if (scoreComparison == null) { + scoreComparison = gt; + } else { + scoreComparison = new And(EMPTY, gt, scoreComparison); + } + } + Filter scoreFilter = new Filter(EMPTY, scoreEval, scoreComparison); + + // Sort on the scores, limit on the minimum k + List orders = new ArrayList<>(numExactQueries); + for (int i = 0; i < numExactQueries; i++) { + orders.add( + new Order( + EMPTY, + scoringAttributes.get(i), + Order.OrderDirection.DESC, + Order.NullsPosition.LAST + ) + ); + } + int minimumK = replaced.get() + .stream() + .map(k -> ((KnnVectorQueryBuilder) k.queryBuilder())) + .mapToInt(KnnVectorQueryBuilder::k) + .min() + .orElseThrow(); + TopN topK = new TopN(EMPTY, scoreFilter, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); + return topK; +// // Aliases resolution function +// AttributeMap evalAliases = aliasesBuilder.build(); +// topK = (TopN) topK.transformExpressionsOnly(ReferenceAttribute.class, r -> evalAliases.resolve(r, r)); +// return topK; + } + + private static Expression replaceNonPushableKnnByTrue(Knn knn, Holder> replaced) { + if (knn.hasNonPushableFilters() == false) { + return knn; + } + + replaced.get().add(knn); + + return Literal.TRUE; + } + + private static Expression replaceKnnByExact(Knn knn) { + Expression minimumSimilarity = knn.options() == null + ? 
null + : ((MapExpression) knn.options()).get(VECTOR_SIMILARITY_FIELD.getPreferredName()); + ExactNN exact = new ExactNN( + knn.source(), + knn.field(), + knn.query(), + minimumSimilarity + ); + // Replaces query builder as it was not resolved during post analysis phase + return exact.replaceQueryBuilder( + TranslatorHandler.TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, exact).toQueryBuilder() + ); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java index 9df73c45604fe..d45b10a07b708 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java @@ -46,10 +46,10 @@ public boolean equals(Object o) { if (super.equals(o) == false) return false; if (o == null || getClass() != o.getClass()) return false; - ExactNNQuery knnQuery = (ExactNNQuery) o; - return Objects.equals(field, knnQuery.field) - && Objects.deepEquals(query, knnQuery.query) - && Objects.equals(minimumSimilarity, knnQuery.minimumSimilarity) + ExactNNQuery query = (ExactNNQuery) o; + return Objects.equals(field, query.field) + && Objects.deepEquals(this.query, query.query) + && Objects.equals(minimumSimilarity, query.minimumSimilarity); } @Override From 363e50ece60f00483dfa859c6cc5c9bfe6f59cf2 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 09:05:34 +0200 Subject: [PATCH 03/24] Refactoring and spotless --- .../xpack/esql/plugin/KnnFunctionIT.java | 7 +- .../expression/function/vector/ExactNN.java | 14 +-- ...nnWithNoPushedDownFiltersWithEvalTopN.java | 113 ++++++++---------- 3 files changed, 53 insertions(+), 81 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java 
b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 07b9825e9b909..daa9430767072 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -32,6 +32,7 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug") public class KnnFunctionIT extends AbstractEsqlIntegTestCase { @@ -189,15 +190,13 @@ public void testKnnPrefiltersNotPushedDown() { var query = String.format(Locale.ROOT, """ FROM test | WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10 - | KEEP id, vector, keyword | SORT id ASC | LIMIT 20 """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "vector", "keyword")); - assertColumnTypes(resp.columns(), List.of("integer", "dense_vector", "keyword")); - + // No added columns + assertThat(resp.columns().size(), equalTo(4)); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(5, valuesList.size()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java index 0935d3e969af3..7189587647790 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java @@ -79,19 +79,12 @@ public ExactNN( optional = true, description = "The minimum similarity required for a document to 
be considered a match. " + "The similarity value calculated relates to the raw similarity used, not the document score." - ) - Expression minimumSimilarity + ) Expression minimumSimilarity ) { this(source, field, query, minimumSimilarity, null); } - public ExactNN( - Source source, - Expression field, - Expression query, - Expression minimumSimilarity, - QueryBuilder queryBuilder - ) { + public ExactNN(Source source, Expression field, Expression query, Expression minimumSimilarity, QueryBuilder queryBuilder) { super(source, query, minimumSimilarity == null ? List.of(field, query) : List.of(field, query, minimumSimilarity), queryBuilder); this.field = field; this.minimumSimilarity = minimumSimilarity; @@ -130,8 +123,7 @@ private TypeResolution resolveMinimumSimilarity() { return TypeResolution.TYPE_RESOLVED; } - return isNotNull(minimumSimilarity(), sourceText(), THIRD) - .and(isNumeric(minimumSimilarity(), sourceText(), THIRD)); + return isNotNull(minimumSimilarity(), sourceText(), THIRD).and(isNumeric(minimumSimilarity(), sourceText(), THIRD)); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java index a43eb7324d9aa..5ed06ac931f63 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java @@ -10,7 +10,6 @@ import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import 
org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; @@ -53,45 +52,55 @@ public ReplaceKnnWithNoPushedDownFiltersWithEvalTopN() { protected LogicalPlan rule(Filter filter) { Expression condition = filter.condition(); - Holder> replaced = new Holder<>(new ArrayList<>()); - Expression conditionWithoutKnns = condition.transformDown(Knn.class, knn -> replaceNonPushableKnnByTrue(knn, replaced)); + Holder> knnQueries = new Holder<>(new ArrayList<>()); + Expression conditionWithoutKnns = condition.transformDown(Knn.class, knn -> replaceNonPushableKnnByTrue(knn, knnQueries)); if (conditionWithoutKnns.equals(condition)) { return filter; } // Replace knn with scoring expressions of exact queries - List exactQueries = replaced.get() + List exactQueries = knnQueries.get() .stream() - .map(ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExact) + .map(ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExactQuery) .toList(); - int numExactQueries = replaced.get().size(); - assert numExactQueries > 0; - List scoringAliases = new ArrayList<>(numExactQueries); - AttributeMap.Builder aliasesBuilder = AttributeMap.builder(); - for (int i = 0; i < numExactQueries; i++) { + assert exactQueries.isEmpty() == false; + + List exactScoreAliases = exactScoreAliases(exactQueries); + Eval scoreEval = new Eval(EMPTY, filter.with(conditionWithoutKnns), exactScoreAliases); + + // Filter for all exact scores > 0 + Filter scoreFilter = exactScoreFilter(exactScoreAliases, scoreEval); + + // Sort on the scores, limit on the minimum k + return topN(exactScoreAliases, knnQueries.get(), scoreFilter); + } + + private static Expression replaceKnnByExactQuery(Knn knn) { + Expression minimumSimilarity = knn.options() == null + ? 
null + : ((MapExpression) knn.options()).get(VECTOR_SIMILARITY_FIELD.getPreferredName()); + ExactNN exact = new ExactNN(knn.source(), knn.field(), knn.query(), minimumSimilarity); + // Replaces query builder as it was not resolved during post analysis phase + return exact.replaceQueryBuilder( + TranslatorHandler.TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, exact).toQueryBuilder() + ); + } + + private static List exactScoreAliases(List exactQueries) { + List scoringAliases = new ArrayList<>(); + for (int i = 0; i < exactQueries.size(); i++) { String name = rawTemporaryName(EXACT_SCORE_ATTR_NAME, String.valueOf(i)); Alias alias = new Alias(EMPTY, name, new Score(EMPTY, exactQueries.get(i))); scoringAliases.add(alias); - aliasesBuilder.put(alias.toAttribute(), alias.child()); } + return scoringAliases; + } - Eval scoreEval = new Eval( - EMPTY, - filter.with(conditionWithoutKnns), - scoringAliases - ); - - // Filter for all exact scores > 0 + private static Filter exactScoreFilter(List scoreAliases, Eval scoreEval) { Expression scoreComparison = null; - List scoringAttributes = new ArrayList<>(numExactQueries); - for (int i = 0; i < numExactQueries; i++) { - Attribute scoringAttr = scoringAliases.get(i).toAttribute(); - scoringAttributes.add(scoringAttr); - GreaterThan gt = new GreaterThan( - EMPTY, - scoringAttr, - new Literal(EMPTY, 0.0, DataType.DOUBLE) - ); + for (Alias scoreAlias : scoreAliases) { + Attribute scoringAttr = scoreAlias.toAttribute(); + GreaterThan gt = new GreaterThan(EMPTY, scoringAttr, new Literal(EMPTY, 0.0, DataType.DOUBLE)); if (scoreComparison == null) { scoreComparison = gt; } else { @@ -99,31 +108,7 @@ protected LogicalPlan rule(Filter filter) { } } Filter scoreFilter = new Filter(EMPTY, scoreEval, scoreComparison); - - // Sort on the scores, limit on the minimum k - List orders = new ArrayList<>(numExactQueries); - for (int i = 0; i < numExactQueries; i++) { - orders.add( - new Order( - EMPTY, - scoringAttributes.get(i), - 
Order.OrderDirection.DESC, - Order.NullsPosition.LAST - ) - ); - } - int minimumK = replaced.get() - .stream() - .map(k -> ((KnnVectorQueryBuilder) k.queryBuilder())) - .mapToInt(KnnVectorQueryBuilder::k) - .min() - .orElseThrow(); - TopN topK = new TopN(EMPTY, scoreFilter, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); - return topK; -// // Aliases resolution function -// AttributeMap evalAliases = aliasesBuilder.build(); -// topK = (TopN) topK.transformExpressionsOnly(ReferenceAttribute.class, r -> evalAliases.resolve(r, r)); -// return topK; + return scoreFilter; } private static Expression replaceNonPushableKnnByTrue(Knn knn, Holder> replaced) { @@ -136,19 +121,15 @@ private static Expression replaceNonPushableKnnByTrue(Knn knn, Holder> return Literal.TRUE; } - private static Expression replaceKnnByExact(Knn knn) { - Expression minimumSimilarity = knn.options() == null - ? null - : ((MapExpression) knn.options()).get(VECTOR_SIMILARITY_FIELD.getPreferredName()); - ExactNN exact = new ExactNN( - knn.source(), - knn.field(), - knn.query(), - minimumSimilarity - ); - // Replaces query builder as it was not resolved during post analysis phase - return exact.replaceQueryBuilder( - TranslatorHandler.TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, exact).toQueryBuilder() - ); + private static TopN topN(List scoreAliases, List knnQueries, Filter scoreFilter) { + List orders = scoreAliases.stream() + .map(a -> new Order(EMPTY, a.toAttribute(), Order.OrderDirection.DESC, Order.NullsPosition.LAST)) + .toList(); + int minimumK = knnQueries.stream() + .map(k -> ((KnnVectorQueryBuilder) k.queryBuilder())) + .mapToInt(KnnVectorQueryBuilder::k) + .min() + .orElseThrow(); + return new TopN(EMPTY, scoreFilter, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); } } From a60aa3ddeab93aa895d284511337cf5c6d707e5e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 09:50:55 +0200 Subject: [PATCH 04/24] Add _score use case --- 
.../xpack/esql/plugin/KnnFunctionIT.java | 20 ++++++ ...nnWithNoPushedDownFiltersWithEvalTopN.java | 64 +++++++++++-------- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index daa9430767072..3788fdf5e10a9 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -202,6 +202,26 @@ public void testKnnPrefiltersNotPushedDown() { } } + public void testKnnPrefiltersNotPushedDownWithScoring() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, but must be prefiltered with the non-pushed down conjunction + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10 + | SORT id ASC + | LIMIT 20 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + // No added columns + assertThat(resp.columns().size(), equalTo(5)); + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(5, valuesList.size()); + } + } + @Before public void setup() throws IOException { assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java index 5ed06ac931f63..5de3e42c852d5 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.Order; @@ -36,6 +37,7 @@ import static org.elasticsearch.xpack.esql.core.expression.Attribute.rawTemporaryName; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; +import static org.elasticsearch.xpack.esql.planner.PlannerUtils.usesScoring; /** * Break TopN back into Limit + OrderBy to allow the order rules to kick in. 
@@ -58,21 +60,36 @@ protected LogicalPlan rule(Filter filter) { return filter; } - // Replace knn with scoring expressions of exact queries - List exactQueries = knnQueries.get() - .stream() - .map(ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExactQuery) - .toList(); - assert exactQueries.isEmpty() == false; - - List exactScoreAliases = exactScoreAliases(exactQueries); - Eval scoreEval = new Eval(EMPTY, filter.with(conditionWithoutKnns), exactScoreAliases); + List scoreAttrs; + LogicalPlan scoringPlan; + if (usesScoring(filter)) { + scoreAttrs = filter.output() + .stream() + .filter(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE)) + .toList(); + // Use the original filter, changing knn to exact queries + scoringPlan = filter.with( + filter.condition().transformDown(Knn.class, ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExactQuery) + ); + } else { + // Replace knn with scoring expressions of exact queries + List exactQueries = knnQueries.get() + .stream() + .map(ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExactQuery) + .toList(); + assert exactQueries.isEmpty() == false; + + // Create an Eval for scoring the exact queries + List exactScoreAliases = exactQueryScoreAliases(exactQueries); + scoringPlan = new Eval(EMPTY, filter.with(conditionWithoutKnns), exactScoreAliases); + scoreAttrs = exactScoreAliases.stream().map(Alias::toAttribute).toList(); + } - // Filter for all exact scores > 0 - Filter scoreFilter = exactScoreFilter(exactScoreAliases, scoreEval); + // Sort on the scores, limit on the minimum k from the queries + TopN topN = createTopN(scoreAttrs, knnQueries.get(), scoringPlan); - // Sort on the scores, limit on the minimum k - return topN(exactScoreAliases, knnQueries.get(), scoreFilter); + // Filter on scores > 0. 
We could filter earlier, but then this filter could be combined with the existing filter and _score would not be updated + return createScoreFilter(scoreAttrs, topN); } private static Expression replaceKnnByExactQuery(Knn knn) { @@ -86,7 +103,7 @@ private static Expression replaceKnnByExactQuery(Knn knn) { ); } - private static List exactScoreAliases(List exactQueries) { + private static List exactQueryScoreAliases(List exactQueries) { List scoringAliases = new ArrayList<>(); for (int i = 0; i < exactQueries.size(); i++) { String name = rawTemporaryName(EXACT_SCORE_ATTR_NAME, String.valueOf(i)); @@ -96,10 +113,9 @@ private static List exactScoreAliases(List exactQueries) { return scoringAliases; } - private static Filter exactScoreFilter(List scoreAliases, Eval scoreEval) { + private static Filter createScoreFilter(List scoreAttrs, LogicalPlan planToFilter) { Expression scoreComparison = null; - for (Alias scoreAlias : scoreAliases) { - Attribute scoringAttr = scoreAlias.toAttribute(); + for (Attribute scoringAttr : scoreAttrs) { GreaterThan gt = new GreaterThan(EMPTY, scoringAttr, new Literal(EMPTY, 0.0, DataType.DOUBLE)); if (scoreComparison == null) { scoreComparison = gt; @@ -107,29 +123,27 @@ private static Filter exactScoreFilter(List scoreAliases, Eval scoreEval) scoreComparison = new And(EMPTY, gt, scoreComparison); } } - Filter scoreFilter = new Filter(EMPTY, scoreEval, scoreComparison); - return scoreFilter; + + return new Filter(EMPTY, planToFilter, scoreComparison); } private static Expression replaceNonPushableKnnByTrue(Knn knn, Holder> replaced) { if (knn.hasNonPushableFilters() == false) { return knn; } - replaced.get().add(knn); - return Literal.TRUE; } - private static TopN topN(List scoreAliases, List knnQueries, Filter scoreFilter) { - List orders = scoreAliases.stream() - .map(a -> new Order(EMPTY, a.toAttribute(), Order.OrderDirection.DESC, Order.NullsPosition.LAST)) + private static TopN createTopN(List scoreAttrs, List knnQueries, LogicalPlan scoringPlan) { +
List orders = scoreAttrs.stream() + .map(a -> new Order(EMPTY, a, Order.OrderDirection.DESC, Order.NullsPosition.LAST)) .toList(); int minimumK = knnQueries.stream() .map(k -> ((KnnVectorQueryBuilder) k.queryBuilder())) .mapToInt(KnnVectorQueryBuilder::k) .min() .orElseThrow(); - return new TopN(EMPTY, scoreFilter, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); + return new TopN(EMPTY, scoringPlan, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); } } From 63c62c7e3e840b7d217a38a721096987cf11cb2e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 11:50:18 +0200 Subject: [PATCH 05/24] Check knn usage in disjunctions for non pushable filters --- ...PostOptimizationPlanVerificationAware.java | 27 ++++++++++ .../esql/expression/function/vector/Knn.java | 50 ++++++++++++++++--- .../xpack/esql/optimizer/LogicalVerifier.java | 5 ++ ...nnWithNoPushedDownFiltersWithEvalTopN.java | 2 +- .../optimizer/LogicalPlanOptimizerTests.java | 33 ++++++++++++ 5 files changed, 109 insertions(+), 8 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java new file mode 100644 index 0000000000000..1ff093e3fdc36 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.capabilities; + +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.function.BiConsumer; + +/** + * Interface implemented by expressions or plans that require validation post logical optimization, + * when the plan and references have been not just resolved but also replaced. + * The interface is similar to {@link PostOptimizationVerificationAware}, but focused on the tree structure + */ +public interface PostOptimizationPlanVerificationAware { + + /** + * Allows the implementer to return a consumer that will perform self-validation in the context of the tree structure the implementer + * is part of. This usually involves checking the type and configuration of the children or that of the parent. + */ + BiConsumer postOptimizationVerification(); +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 78f63d7be4b67..29ed31f908132 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -11,8 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; -import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.capabilities.PostOptimizationPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; @@ -35,25 +34,30 @@ import 
org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import static java.util.Map.entry; import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; +import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -66,7 +70,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; -public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { +public class Knn extends FullTextFunction implements 
OptionalArgument, VectorFunction, PostOptimizationPlanVerificationAware { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); @@ -260,8 +264,8 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato for (Expression filterExpression : filterExpressions()) { if (filterExpression instanceof TranslationAware translationAware) { // We can only translate filter expressions that are translatable. In case any is not translatable, - // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them - // when creating an evaluator for the non-pushed down query + // Knn won't be pushed down so it's safe not to translate all filters and check them when creating an evaluator + // for the non-pushed down query if (translationAware.translatable(pushdownPredicates) == Translatable.YES) { filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); } @@ -275,8 +279,16 @@ public Expression withFilters(List filterExpressions) { return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } - public boolean hasNonPushableFilters() { - return filterExpressions().size() > ((KnnVectorQueryBuilder) queryBuilder()).filterQueries().size(); + public Collection nonPushableFilters() { + List nonPushableFilters = new ArrayList<>(); + for (Expression filterExpression : filterExpressions()) { + if (filterExpression instanceof TranslationAware translationAware) { + if (translationAware.translatable(LucenePushdownPredicates.DEFAULT) == Translatable.NO) { + nonPushableFilters.add(filterExpression); + } + } + } + return nonPushableFilters; } private Map queryOptions() throws InvalidArgumentException { @@ -295,6 +307,30 @@ public BiConsumer postAnalysisPlanVerification() { }; } + @Override + public BiConsumer postOptimizationVerification() { + return (plan, failures) -> { + if (plan 
instanceof Filter f) { + f.condition().forEachDown(Or.class, or -> { + or.forEachDown(Knn.class, knn -> { + Collection nonPushableFilters = knn.nonPushableFilters(); + if (nonPushableFilters.isEmpty() == false) { + failures.add( + fail( + plan, + "knn function [{}] cannot be used in an OR clause when it is being filtered with " + + "the following AND conditions: {}.", + knn.sourceText(), + nonPushableFilters.stream().map(Expression::sourceText).collect(Collectors.joining(", ")) + ) + ); + } + }); + }); + } + }; + } + @Override public Expression replaceChildren(List newChildren) { return new Knn( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java index 6751ae4cd2d80..09d5191c9f502 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.capabilities.PostOptimizationPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.optimizer.rules.PlanConsistencyChecker; @@ -38,10 +39,14 @@ public Failures verify(LogicalPlan plan, boolean skipRemoteEnrichVerification) { if (failures.hasFailures() == false) { if (p instanceof PostOptimizationVerificationAware pova) { pova.postOptimizationVerification(failures); + } else if (p instanceof PostOptimizationPlanVerificationAware popva) { + popva.postOptimizationVerification().accept(p, failures); } p.forEachExpression(ex -> { if (ex instanceof PostOptimizationVerificationAware va) { va.postOptimizationVerification(failures); + } else if (ex instanceof PostOptimizationPlanVerificationAware pva) { + 
pva.postOptimizationVerification().accept(p, failures); } }); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java index 5de3e42c852d5..0d19db5520c16 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java @@ -128,7 +128,7 @@ private static Filter createScoreFilter(List scoreAttrs, LogicalPlan } private static Expression replaceNonPushableKnnByTrue(Knn knn, Holder> replaced) { - if (knn.hasNonPushableFilters() == false) { + if (knn.nonPushableFilters().isEmpty()) { return knn; } replaced.get().add(knn); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index e301c1610bd7b..593cb56bec244 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -8026,4 +8026,37 @@ public void testMultipleKnnQueriesInPrefilters() { assertThat(secondKnnFilters.size(), equalTo(1)); assertTrue(secondKnnFilters.contains(firstOr.right())); } + + public void testKnnInDisjunctions() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + // Disjunctions with pushable conditions are allowed + planTypes("from types | where (knn(dense_vector, [0.1, 0.2, 0.3], 10) or match(text, \"hello\")) " + "and keyword == \"prod\""); + planTypes( + 
"from types | where ((knn(dense_vector, [0.1, 0.2, 0.3], 10) and match(text, \"hello\")) or keyword == \"hello\")" + + "and (keyword ==\"prod\" or long == 50)" + ); + + // Disjunctions with non-pushable conditions as a prefilter must fail + assertThat( + typesError( + "from types | where (knn(dense_vector, [0.1, 0.2, 0.3], 10) or match(text, \"hello\")) " + "and length(keyword) > 10" + ), + containsString( + "knn function [knn(dense_vector, [0.1, 0.2, 0.3], 10)] cannot be used in an OR clause " + + "when it is being filtered with the following AND conditions: length(keyword) > 10." + ) + ); + + assertThat( + typesError( + "from types | where ((knn(dense_vector, [0.1, 0.2, 0.3], 10) and match(text, \"hello\")) or keyword == \"hello\")" + + "and (length(keyword) > 10 or long == 50)" + ), + containsString( + "knn function [knn(dense_vector, [0.1, 0.2, 0.3], 10)] cannot be used in an OR clause " + + "when it is being filtered with the following AND conditions: length(keyword) > 10 or long == 50." 
+ ) + ); + } } From 4440717074b924759dcc65fc84a06c27958388ea Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 11:51:59 +0200 Subject: [PATCH 06/24] Rename --- .../xpack/esql/optimizer/LocalLogicalPlanOptimizer.java | 4 ++-- ...alTopN.java => ReplaceKnnWithNoPushedDownFilters.java} | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) rename x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/{ReplaceKnnWithNoPushedDownFiltersWithEvalTopN.java => ReplaceKnnWithNoPushedDownFilters.java} (95%) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index ef33f8529b7bc..4d0815db7c5e3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -17,7 +17,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalPropagateEmptyRelation; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalSubstituteSurrogateExpressions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceFieldWithConstantOrNull; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceKnnWithNoPushedDownFiltersWithEvalTopN; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceKnnWithNoPushedDownFilters; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNWithLimitAndSort; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor; @@ -50,7 +50,7 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor { +public class ReplaceKnnWithNoPushedDownFilters extends OptimizerRules.OptimizerRule { public static final 
String EXACT_SCORE_ATTR_NAME = "knn_score"; - public ReplaceKnnWithNoPushedDownFiltersWithEvalTopN() { + public ReplaceKnnWithNoPushedDownFilters() { super(UP); } @@ -69,13 +69,13 @@ protected LogicalPlan rule(Filter filter) { .toList(); // Use the original filter, changing knn to exact queries scoringPlan = filter.with( - filter.condition().transformDown(Knn.class, ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExactQuery) + filter.condition().transformDown(Knn.class, ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery) ); } else { // Replace knn with scoring expressions of exact queries List exactQueries = knnQueries.get() .stream() - .map(ReplaceKnnWithNoPushedDownFiltersWithEvalTopN::replaceKnnByExactQuery) + .map(ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery) .toList(); assert exactQueries.isEmpty() == false; From dee1e917259988dcd3c1f8e67d5ba93ffa27dde3 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 13:28:44 +0200 Subject: [PATCH 07/24] Move ReplaceKnnWithNoPushedDownFilters to logical optimizer --- .../esql/optimizer/LocalLogicalPlanOptimizer.java | 4 +--- .../xpack/esql/optimizer/LogicalPlanOptimizer.java | 6 ++++-- .../ReplaceKnnWithNoPushedDownFilters.java | 11 +++-------- 3 files changed, 8 insertions(+), 13 deletions(-) rename x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/{local => }/ReplaceKnnWithNoPushedDownFilters.java (94%) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index 4d0815db7c5e3..39f37f952ae02 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -17,7 +17,6 @@ import 
org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalPropagateEmptyRelation; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalSubstituteSurrogateExpressions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceFieldWithConstantOrNull; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceKnnWithNoPushedDownFilters; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNWithLimitAndSort; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor; @@ -49,8 +48,7 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor substitutions() { new ReplaceAliasingEvalWithProject(), new SkipQueryOnEmptyMappings(), new SubstituteSurrogateExpressions(), - new ReplaceOrderByExpressionWithEval() + new ReplaceOrderByExpressionWithEval(), + new PushDownConjunctionsToKnnPrefilters(), + new ReplaceKnnWithNoPushedDownFilters() // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634 ); } @@ -193,7 +196,6 @@ protected static Batch operators(boolean local) { new PruneLiteralsInOrderBy(), new PushDownAndCombineLimits(), new PushDownAndCombineFilters(), - new PushDownConjunctionsToKnnPrefilters(), new PushDownAndCombineSample(), new PushDownInferencePlan(), new PushDownEval(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java similarity index 94% rename from x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFilters.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java index 9beebab80248b..6b1de8f9ae838 
100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceKnnWithNoPushedDownFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java @@ -5,12 +5,12 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; +package org.elasticsearch.xpack.esql.optimizer.rules.logical; -import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; @@ -139,11 +138,7 @@ private static TopN createTopN(List scoreAttrs, List knnQueries, List orders = scoreAttrs.stream() .map(a -> new Order(EMPTY, a, Order.OrderDirection.DESC, Order.NullsPosition.LAST)) .toList(); - int minimumK = knnQueries.stream() - .map(k -> ((KnnVectorQueryBuilder) k.queryBuilder())) - .mapToInt(KnnVectorQueryBuilder::k) - .min() - .orElseThrow(); + int minimumK = knnQueries.stream().mapToInt(knn -> (Integer) knn.k().fold(FoldContext.small())).min().orElseThrow(); return new TopN(EMPTY, 
scoringPlan, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); } } From 6438bc17711abe26fa3387fb060f52b90e4319dc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 13:28:54 +0200 Subject: [PATCH 08/24] Add tests --- .../optimizer/LogicalPlanOptimizerTests.java | 98 +++++++++++++++++-- 1 file changed, 92 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 593cb56bec244..bc19e7bf6cd2f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -58,6 +59,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Score; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; @@ -74,6 +76,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import 
org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; @@ -179,6 +182,7 @@ import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters.EXACT_SCORE_ATTR_NAME; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.contains; @@ -8027,15 +8031,97 @@ public void testMultipleKnnQueriesInPrefilters() { assertTrue(secondKnnFilters.contains(firstOr.right())); } - public void testKnnInDisjunctions() { + public void testKnnInWithNonPushablePrefiltersNoScoring() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); // Disjunctions with pushable conditions are allowed - planTypes("from types | where (knn(dense_vector, [0.1, 0.2, 0.3], 10) or match(text, \"hello\")) " + "and keyword == \"prod\""); - planTypes( - "from types | where ((knn(dense_vector, [0.1, 0.2, 0.3], 10) and match(text, \"hello\")) or keyword == \"hello\")" - + "and (keyword ==\"prod\" or long == 50)" - ); + var plan = planTypes(""" + from types + | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 + """); + + var limit = as(plan, Limit.class); + 
assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // Next: Filter[$$knn_score$0 > 0.0] + var filter = as(limit.child(), Filter.class); + var gt = as(filter.condition(), GreaterThan.class); + ReferenceAttribute scoreAttr = as(gt.left(), ReferenceAttribute.class); + assertThat(scoreAttr.toString(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt.right().fold(FoldContext.small()), equalTo(0.0)); + + // Next: TopN[..., 5] + var topN = as(filter.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); + assertThat(topN.order().getFirst().child(), equalTo(scoreAttr)); + + // Next: Eval[SCORE(EXACTNN(...)) AS $$knn_score$...] + var eval = as(topN.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), equalTo(scoreAttr.name())); + var score = as(alias.child(), Score.class); + var exactNN = as(score.children().getFirst(), ExactNN.class); + var field = as(exactNN.field(), FieldAttribute.class); + assertThat(field.name(), equalTo("dense_vector")); + assertThat(exactNN.query().toString(), equalTo("[0.1, 0.2, 0.3]")); + + var prefilter = as(eval.child(), Filter.class); + var and = as(prefilter.condition(), And.class); + as(and.left(), Match.class); + var lenGt = as(and.right(), GreaterThan.class); + assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); + assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); + + // Next: EsRelation[types] + var esRelation = as(prefilter.child(), EsRelation.class); + } + + public void testKnnInWithNonPushablePrefiltersScoring() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + // Disjunctions with pushable conditions are allowed + var plan = planTypes(""" + from types metadata _score + | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 + """); + + // Top-level: Limit[1000[INTEGER],false] + var limit = as(plan, Limit.class);
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // Next: Filter[_score > 0.0] + var filter = as(limit.child(), Filter.class); + var gt = as(filter.condition(), GreaterThan.class); + MetadataAttribute scoreAttr = as(gt.left(), MetadataAttribute.class); + assertThat(scoreAttr.name(), equalTo("_score")); + assertThat(gt.right().fold(FoldContext.small()), equalTo(0.0)); + + // Next: TopN[..., 5] + var topN = as(filter.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); + assertThat(topN.order().getFirst().child(), equalTo(scoreAttr)); + + // Next: Filter[EXACTNN(...) AND MATCH(...) AND LENGTH(keyword) > 10] + var innerFilter = as(topN.child(), Filter.class); + var and1 = as(innerFilter.condition(), And.class); + var and2 = as(and1.left(), And.class); + + var exactNN = as(and2.left(), ExactNN.class); + var field = as(exactNN.field(), FieldAttribute.class); + assertThat(field.name(), equalTo("dense_vector")); + assertThat(exactNN.query().toString(), equalTo("[0.1, 0.2, 0.3]")); + var match = as(and2.right(), Match.class); + var lenGt = as(and1.right(), GreaterThan.class); + + assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); + assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); + + // Next: EsRelation[types] + var esRelation = as(innerFilter.child(), EsRelation.class); + } + + public void testKnnInDisjunctionsWithNonPushablePrefilters() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); // Disjunctions with non-pushable conditions as a prefilter must fail assertThat( From 21bb1f146ce4be3167710b34b00f0abea0ab7e24 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 13:29:37 +0200 Subject: [PATCH 09/24] Fix knn refs --- .../expression/function/vector/ExactNN.java | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java index 7189587647790..347e71df5098a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java @@ -53,7 +53,11 @@ */ public class ExactNN extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { - public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", ExactNN::readFrom); + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "ExactNN", + ExactNN::readFrom + ); private final Expression field; private final Expression minimumSimilarity; @@ -197,14 +201,14 @@ public void writeTo(StreamOutput out) throws IOException { @Override public boolean equals(Object o) { - // Knn does not serialize options, as they get included in the query builder. We need to override equals and hashcode to - // ignore options when comparing two Knn functions + // ExactNN does not serialize options, as they get included in the query builder. 
We need to override equals and hashcode to + // ignore options when comparing two ExactNN functions if (o == null || getClass() != o.getClass()) return false; - ExactNN knn = (ExactNN) o; - return Objects.equals(field(), knn.field()) - && Objects.equals(query(), knn.query()) - && Objects.equals(minimumSimilarity(), knn.minimumSimilarity()) - && Objects.equals(queryBuilder(), knn.queryBuilder()); + ExactNN exact = (ExactNN) o; + return Objects.equals(field(), exact.field()) + && Objects.equals(query(), exact.query()) + && Objects.equals(minimumSimilarity(), exact.minimumSimilarity()) + && Objects.equals(queryBuilder(), exact.queryBuilder()); } @Override From 376be41ee5f5a3199cf3e87600f25fb07e2ef8e5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 13:45:45 +0200 Subject: [PATCH 10/24] Add tests --- .../optimizer/LogicalPlanOptimizerTests.java | 128 +++++++++++++++++- 1 file changed, 126 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index bc19e7bf6cd2f..7cf8f5e18d406 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -8031,7 +8031,7 @@ public void testMultipleKnnQueriesInPrefilters() { assertTrue(secondKnnFilters.contains(firstOr.right())); } - public void testKnnInWithNonPushablePrefiltersNoScoring() { + public void testKnnWithNonPushablePrefiltersNoScoring() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); // Disjunctions with pushable conditions are allowed @@ -8076,7 +8076,73 @@ public void testKnnInWithNonPushablePrefiltersNoScoring() { var esRelation = as(prefilter.child(), EsRelation.class); } - public void 
testKnnInWithNonPushablePrefiltersScoring() { + public void testKnnWithNonPushablePrefiltersNoScoringMultipleKnn() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + // Disjunctions with pushable conditions are allowed + var plan = planTypes(""" + from types + | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and knn(dense_vector, [0.4, 0.5, 0.6], 7) + and match(text, "hello") and length(keyword) > 10 + """); + + var limit = as(plan, Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // Next: Filter[$$knn_score$1 > 0.0 AND $$knn_score$0 > 0.0] + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + + // Both sides are GreaterThan for the two score attrs + var gt1 = as(and.left(), GreaterThan.class); + var gt2 = as(and.right(), GreaterThan.class); + + ReferenceAttribute scoreAttr0 = as(gt1.left(), ReferenceAttribute.class); + assertThat(scoreAttr0.name(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt1.right().fold(FoldContext.small()), equalTo(0.0)); + ReferenceAttribute scoreAttr1 = as(gt2.left(), ReferenceAttribute.class); + assertThat(scoreAttr1.name(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt2.right().fold(FoldContext.small()), equalTo(0.0)); + + // Next: TopN[..., 5] + var topN = as(filter.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); + assertThat(topN.order().size(), equalTo(2)); + assertThat(topN.order().get(0).child(), equalTo(scoreAttr1)); + assertThat(topN.order().get(1).child(), equalTo(scoreAttr0)); + + // Next: Eval[SCORE(EXACTNN(...)) AS $$knn_score$0, ...] 
+ var eval = as(topN.child(), Eval.class); + assertThat(eval.fields().size(), equalTo(2)); + var alias0 = as(eval.fields().get(0), Alias.class); + assertThat(alias0.name(), equalTo(scoreAttr1.name())); + var score0 = as(alias0.child(), Score.class); + var exactNN0 = as(score0.children().getFirst(), ExactNN.class); + var field0 = as(exactNN0.field(), FieldAttribute.class); + assertThat(field0.name(), equalTo("dense_vector")); + assertThat(exactNN0.query().toString(), equalTo("[0.1, 0.2, 0.3]")); + + var alias1 = as(eval.fields().get(1), Alias.class); + assertThat(alias1.name(), equalTo(scoreAttr0.name())); + var score1 = as(alias1.child(), Score.class); + var exactNN1 = as(score1.children().getFirst(), ExactNN.class); + var field1 = as(exactNN1.field(), FieldAttribute.class); + assertThat(field1.name(), equalTo("dense_vector")); + assertThat(exactNN1.query().toString(), equalTo("[0.4, 0.5, 0.6]")); + + // Next: Filter[MATCH(...) AND LENGTH(keyword) > 10] + var prefilter = as(eval.child(), Filter.class); + var andPref = as(prefilter.condition(), And.class); + as(andPref.left(), Match.class); + var lenGt = as(andPref.right(), GreaterThan.class); + assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); + assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); + + // Next: EsRelation[types] + var esRelation = as(prefilter.child(), EsRelation.class); + } + + public void testKnnWithNonPushablePrefiltersScoring() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); // Disjunctions with pushable conditions are allowed @@ -8120,6 +8186,64 @@ public void testKnnInWithNonPushablePrefiltersScoring() { var esRelation = as(innerFilter.child(), EsRelation.class); } + public void testKnnWithNonPushablePrefiltersScoringMultipleKnns() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + // Disjunctions with pushable conditions are allowed + var plan = planTypes(""" + from types metadata 
_score + | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 + and knn(dense_vector, [0.3, 0.4, 0.5], 7) + """); + + // Top-level: Limit[1000[INTEGER],false] + var limit = as(plan, Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // Next: Filter[_score > 0.0] + var filter = as(limit.child(), Filter.class); + var gt = as(filter.condition(), GreaterThan.class); + MetadataAttribute scoreAttr = as(gt.left(), MetadataAttribute.class); + assertThat(scoreAttr.name(), equalTo("_score")); + assertThat(gt.right().fold(FoldContext.small()), equalTo(0.0)); + + // Next: TopN[..., 5] + var topN = as(filter.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); + assertThat(topN.order().getFirst().child(), equalTo(scoreAttr)); + + // Next: Filter[EXACTNN(...) AND MATCH(...) AND LENGTH(keyword) > 10 AND EXACTNN(...)] + var innerFilter = as(topN.child(), Filter.class); + // Should be a chain of ANDs + var and1 = as(innerFilter.condition(), And.class); + var and2 = as(and1.left(), And.class); + var and3 = as(and2.left(), And.class); + + // First EXACTNN + var exactNN1 = as(and3.left(), ExactNN.class); + var field1 = as(exactNN1.field(), FieldAttribute.class); + assertThat(field1.name(), equalTo("dense_vector")); + assertThat(exactNN1.query().toString(), equalTo("[0.1, 0.2, 0.3]")); + + // MATCH + var match = as(and3.right(), Match.class); + + // LENGTH(keyword) > 10 + var lenGt = as(and2.right(), GreaterThan.class); + assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); + assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); + + // Second EXACTNN + var exactNN2 = as(and1.right(), ExactNN.class); + var field2 = as(exactNN2.field(), FieldAttribute.class); + assertThat(field2.name(), equalTo("dense_vector")); + assertThat(exactNN2.query().toString(), equalTo("[0.3, 0.4, 0.5]")); + + // Next: EsRelation[types] + var esRelation = 
as(innerFilter.child(), EsRelation.class); + + } + public void testKnnInDisjunctionsWithNonPushablePrefilters() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); From a24645d9aec90d059ed945777dd00f41a8d7599b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 13:46:37 +0200 Subject: [PATCH 11/24] capability bump --- .../xpack/esql/plugin/KnnFunctionIT.java | 2 +- .../xpack/esql/action/EsqlCapabilities.java | 2 +- .../function/vector/VectorWritables.java | 2 +- .../elasticsearch/xpack/esql/CsvTests.java | 2 +- .../xpack/esql/analysis/VerifierTests.java | 18 +++++++-------- .../function/fulltext/KnnTests.java | 2 +- .../LocalPhysicalPlanOptimizerTests.java | 18 +++++++-------- .../optimizer/LogicalPlanOptimizerTests.java | 22 +++++++++---------- 8 files changed, 34 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 3788fdf5e10a9..5f38f2fcfdf8b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -224,7 +224,7 @@ public void testKnnPrefiltersNotPushedDownWithScoring() { @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var indexName = "test"; var client = client().admin().indices(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 5c97879dd6a6d..187724cb44f8a 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1223,7 +1223,7 @@ public enum Cap { /** * Support knn function */ - KNN_FUNCTION_V3(Build.current().isSnapshot()), + KNN_FUNCTION_V4(Build.current().isSnapshot()), /** * Support for the LIKE operator with a list of wildcards. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java index a4274bf28de4b..5eaecac12d80f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -27,7 +27,7 @@ private VectorWritables() { public static List getNamedWritables() { List entries = new ArrayList<>(); - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { entries.add(Knn.ENTRY); } if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 62280a38ba608..a5c37c26c9041 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -303,7 +303,7 @@ public final void test() throws Throwable { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V4.capabilityName()) ); assumeFalse( "lookup join disabled for csv 
tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 5d4260eb4ee66..8d79b8a91eb74 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1239,7 +1239,7 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); } } @@ -1372,7 +1372,7 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); } @@ -1427,7 +1427,7 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); } } @@ -1492,7 +1492,7 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, 
\"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); } } @@ -1563,7 +1563,7 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -2142,7 +2142,7 @@ public void testFullTextFunctionOptions() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); } } @@ -2230,7 +2230,7 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); @@ -2256,7 +2256,7 @@ public void testFullTextFunctionsConstantArg() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsConstantArg("term(title, tags)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if 
(EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second"); checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third"); } @@ -2287,7 +2287,7 @@ public void testFullTextFunctionsInStats() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 595eb58118a09..3ce00fadb95cc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -51,7 +51,7 @@ public static Iterable parameters() { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); } private static List testCaseSuppliers() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index a604e1d26d313..393e77e191e1a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1374,7 +1374,7 @@ public void testMultiMatchOptionsPushDown() { public void 
testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1840,7 +1840,7 @@ public void testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) { } public void testKnnPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1872,7 +1872,7 @@ public void testKnnPrefilters() { } public void testKnnPrefiltersWithMultipleFilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1908,7 +1908,7 @@ public void testKnnPrefiltersWithMultipleFilters() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1945,7 +1945,7 @@ public void testPushDownConjunctionsToKnnPrefilter() { } public void testPushDownNegatedConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1982,7 +1982,7 @@ public void testPushDownNegatedConjunctionsToKnnPrefilter() { } public void testNotPushDownDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", 
EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -2011,7 +2011,7 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { } public void testNotPushDownKnnWithNonPushablePrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -2045,7 +2045,7 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { } public void testPushDownComplexNegationsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -2095,7 +2095,7 @@ and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) } public void testMultipleKnnQueriesInPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 7cf8f5e18d406..42c54f6f4dca5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -7862,7 +7862,7 @@ public void testSampleNoPushDownChangePoint() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test @@ -7882,7 +7882,7 @@ public void testPushDownConjunctionsToKnnPrefilter() 
{ } public void testPushDownMultipleFiltersToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test @@ -7905,7 +7905,7 @@ public void testPushDownMultipleFiltersToKnnPrefilter() { } public void testNotPushDownDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test @@ -7922,7 +7922,7 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { } public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* and @@ -7957,7 +7957,7 @@ public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { } public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* or @@ -7989,7 +7989,7 @@ public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { } public void testMultipleKnnQueriesInPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* and @@ -8032,7 +8032,7 @@ public void testMultipleKnnQueriesInPrefilters() { } public void testKnnWithNonPushablePrefiltersNoScoring() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); // Disjunctions with pushable conditions are allowed var plan = planTypes(""" @@ 
-8077,7 +8077,7 @@ public void testKnnWithNonPushablePrefiltersNoScoring() { } public void testKnnWithNonPushablePrefiltersNoScoringMultipleKnn() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); // Disjunctions with pushable conditions are allowed var plan = planTypes(""" @@ -8143,7 +8143,7 @@ and match(text, "hello") and length(keyword) > 10 } public void testKnnWithNonPushablePrefiltersScoring() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); // Disjunctions with pushable conditions are allowed var plan = planTypes(""" @@ -8187,7 +8187,7 @@ public void testKnnWithNonPushablePrefiltersScoring() { } public void testKnnWithNonPushablePrefiltersScoringMultipleKnns() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); // Disjunctions with pushable conditions are allowed var plan = planTypes(""" @@ -8245,7 +8245,7 @@ and knn(dense_vector, [0.3, 0.4, 0.5], 7) } public void testKnnInDisjunctionsWithNonPushablePrefilters() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); // Disjunctions with non-pushable conditions as a prefilter must fail assertThat( From 87285f1fdcf6f8af8482ec594636a642e7dd4679 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 13:48:08 +0200 Subject: [PATCH 12/24] capability bump --- .../src/main/resources/knn-function.csv-spec | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index 
ce8061534ddbb..5f824cc1f9e05 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,7 +3,7 @@ # top-n query at the shard level knnSearch -required_capability: knn_function_v3 +required_capability: knn_function_v4 // tag::knn-function[] from colors metadata _score @@ -30,7 +30,7 @@ chartreuse | [127.0, 255.0, 0.0] ; knnSearchWithSimilarityOption -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) @@ -46,7 +46,7 @@ wheat | [245.0, 222.0, 179.0] ; knnHybridSearch -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where match(color, "blue") or knn(rgb_vector, [65,105,225], 10) @@ -68,7 +68,7 @@ yellow | [255.0, 255.0, 0.0] ; knnWithPrefilter -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green")) @@ -82,7 +82,7 @@ green | [0.0, 128.0, 0.0] ; knnWithNegatedPrefilter -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) @@ -105,7 +105,7 @@ orange | [255.0, 165.0, 0.0] ; knnAfterKeep -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | keep rgb_vector, color, _score @@ -124,7 +124,7 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | drop primary @@ -143,7 +143,7 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 
0 @@ -162,7 +162,7 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" @@ -181,7 +181,7 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true @@ -204,7 +204,7 @@ yellow | [255.0, 255.0, 0.0] ; knnWithNegationsAndFiltersConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) @@ -227,7 +227,7 @@ azure | [240.0, 255.0, 255.0] ; knnWithNonPushableConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 @@ -251,7 +251,7 @@ maroon | false ; testKnnWithNonPushableDisjunctions -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 @@ -267,7 +267,7 @@ papaya whip ; testKnnWithNonPushableDisjunctionsOnComplexExpressions -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) @@ -282,7 +282,7 @@ indigo | false ; testKnnInStatsNonPushable -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors | where length(color) < 10 @@ -294,7 +294,7 @@ c: long ; testKnnInStatsWithGrouping -required_capability: knn_function_v3 +required_capability: knn_function_v4 
required_capability: full_text_functions_in_stats_where from colors From aa89a4d2673d8e6d84f912ed6e095c33890af7b1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 14:33:49 +0200 Subject: [PATCH 13/24] Don't expect blocks to be the first ones on score evaluators --- .../compute/lucene/LuceneQueryEvaluator.java | 14 ++++++++++---- .../compute/operator/ScoreOperator.java | 6 +++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java index d91df60621fce..d268206cff3ff 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java @@ -61,10 +61,16 @@ protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards) } public Block executeQuery(Page page) { - // Lucene based operators retrieve DocVectors as first block - Block block = page.getBlock(0); - assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input"; - DocVector docs = (DocVector) block.asVector(); + // Search for DocVector block + Block docBlock = null; + for (int i = 0; i < page.getBlockCount(); i++) { + if (page.getBlock(i) instanceof DocBlock) { + docBlock = page.getBlock(i); + break; + } + } + assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock"; + DocVector docs = (DocVector) docBlock.asVector(); try { if (docs.singleSegmentNonDecreasing()) { return evalSingleSegmentNonDecreasing(docs).asBlock(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java index 2afc885d71124..e52c1276d540c 100644 --- 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java @@ -46,9 +46,9 @@ public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int sco @Override protected Page process(Page page) { - assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount(); - assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector(); - assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector(); + assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition; + assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector + : "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector(); Block[] blocks = new Block[page.getBlockCount()]; for (int i = 0; i < page.getBlockCount(); i++) { From beea0127fe71689ad54dda350e631f5c95d741c3 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 17:35:29 +0200 Subject: [PATCH 14/24] Move around optimizer rules. Avoid doing optimization for invalid rule that will be detected later. 
--- .../esql/optimizer/LogicalPlanOptimizer.java | 6 +++--- .../ReplaceKnnWithNoPushedDownFilters.java | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 6c69520e2d2ab..664d1e5924cbb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -157,9 +157,7 @@ protected static Batch substitutions() { new ReplaceAliasingEvalWithProject(), new SkipQueryOnEmptyMappings(), new SubstituteSurrogateExpressions(), - new ReplaceOrderByExpressionWithEval(), - new PushDownConjunctionsToKnnPrefilters(), - new ReplaceKnnWithNoPushedDownFilters() + new ReplaceOrderByExpressionWithEval() // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634 ); } @@ -196,6 +194,8 @@ protected static Batch operators(boolean local) { new PruneLiteralsInOrderBy(), new PushDownAndCombineLimits(), new PushDownAndCombineFilters(), + new PushDownConjunctionsToKnnPrefilters(), + new ReplaceKnnWithNoPushedDownFilters(), new PushDownAndCombineSample(), new PushDownInferencePlan(), new PushDownEval(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java index 6b1de8f9ae838..f2b373aead068 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java @@ -21,6 +21,7 @@ 
import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.Eval; @@ -30,6 +31,7 @@ import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; @@ -59,6 +61,20 @@ protected LogicalPlan rule(Filter filter) { return filter; } + // Check that knn is not part of a disjunction + Holder hasNonPushableDisjunctions = new Holder<>(false); + filter.condition().forEachDown(Or.class, or -> { + or.forEachDown(Knn.class, knn -> { + Collection nonPushableFilters = knn.nonPushableFilters(); + if (nonPushableFilters.isEmpty() == false) { + hasNonPushableDisjunctions.set(true); + } + }); + }); + if (hasNonPushableDisjunctions.get()) { + return filter; + } + List scoreAttrs; LogicalPlan scoringPlan; if (usesScoring(filter)) { From bf05dddef99bc2197c4b40ff1d5da6164c9a0655 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 17:35:35 +0200 Subject: [PATCH 15/24] Add CSV tests --- .../src/main/resources/knn-function.csv-spec | 58 ++++++++++++------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index 5f824cc1f9e05..0a7cd0492b942 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ 
b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -231,54 +231,68 @@ required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], 140) and composed_name == false +| where knn(rgb_vector, [128,128,0], 10) and composed_name == false | sort _score desc, color asc -| keep color, composed_name -| limit 10 +| keep color ; -color:text | composed_name:boolean -olive | false -sienna | false -chocolate | false -peru | false -brown | false -firebrick | false -chartreuse | false -gray | false -green | false -maroon | false +color:text +olive +sienna +chocolate +peru +brown +firebrick +chartreuse +green +maroon ; testKnnWithNonPushableDisjunctions required_capability: knn_function_v4 from colors metadata _score -| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 +| where knn(rgb_vector, [128,128,0], 10) or length(color) > 10 | sort _score desc, color asc | keep color ; color:text olive +sienna +chocolate +peru +golden rod +brown +firebrick +chartreuse +green +maroon aqua marine lemon chiffon papaya whip ; -testKnnWithNonPushableDisjunctionsOnComplexExpressions +testKnnWithNonPushableConjunctionsOnComplexExpressions required_capability: knn_function_v4 from colors metadata _score -| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) +| where knn(rgb_vector, [128,128,0], 10) and length(color) < 7 and knn(rgb_vector, [128,0,128], 10) and primary == false | sort _score desc, color asc -| keep color, primary +| keep color ; -color:text | primary:boolean -olive | false -purple | false -indigo | false +color:text +olive +purple +indigo +sienna +brown +peru +maroon +navy +tomato +orange ; testKnnInStatsNonPushable From 9e673f3a967c28e8cd0c482a27a5334ef104227e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 
18:16:56 +0200 Subject: [PATCH 16/24] Register and fix ExactNN --- .../xpack/esql/expression/function/EsqlFunctionRegistry.java | 4 +++- .../xpack/esql/expression/function/vector/ExactNN.java | 4 ++-- .../esql/expression/function/vector/VectorWritables.java | 2 ++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 0c2629596a9b4..1343aa5c66a79 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -181,6 +181,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -491,7 +492,8 @@ private static FunctionDefinition[][] snapshotFunctions() { def(StGeohex.class, StGeohex::new, "st_geohex"), def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"), def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"), - def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } }; + def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"), + def(ExactNN.class, tri(ExactNN::new), "exact_nn") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java index 347e71df5098a..0942cb11f238d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java @@ -185,7 +185,7 @@ private static ExactNN readFrom(StreamInput in) throws IOException { Source source = Source.readFrom((PlanStreamInput) in); Expression field = in.readNamedWriteable(Expression.class); Expression query = in.readNamedWriteable(Expression.class); - Expression minimumSimilarity = in.readNamedWriteable(Expression.class); + Expression minimumSimilarity = in.readOptionalNamedWriteable(Expression.class); QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); return new ExactNN(source, field, query, minimumSimilarity, queryBuilder); } @@ -195,7 +195,7 @@ public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); out.writeNamedWriteable(field()); out.writeNamedWriteable(query()); - out.writeNamedWriteable(minimumSimilarity()); + out.writeOptionalNamedWriteable(minimumSimilarity()); out.writeOptionalNamedWriteable(queryBuilder()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java index 5eaecac12d80f..a3fb133bbde33 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -29,6 +29,8 @@ public static List getNamedWritables() { if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { entries.add(Knn.ENTRY); + // ExactNN is needed as a KNN optimization + entries.add(ExactNN.ENTRY); } if 
(EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { entries.add(CosineSimilarity.ENTRY); From 0ff8c238ef24f8958645de6c6d21fe2b80cdb2dc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 18:17:09 +0200 Subject: [PATCH 17/24] Added a projection to get rid of extra columns --- .../ReplaceKnnWithNoPushedDownFilters.java | 45 +++---- .../optimizer/LogicalPlanOptimizerTests.java | 117 ++---------------- 2 files changed, 27 insertions(+), 135 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java index f2b373aead068..df69634f6977c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; -import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.Order; @@ -27,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; @@ -38,10 +38,9 @@ import static org.elasticsearch.xpack.esql.core.expression.Attribute.rawTemporaryName; import static 
org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; -import static org.elasticsearch.xpack.esql.planner.PlannerUtils.usesScoring; /** - * Break TopN back into Limit + OrderBy to allow the order rules to kick in. + * Replaces KNN queries with non pushable prefilters used in filters */ public class ReplaceKnnWithNoPushedDownFilters extends OptimizerRules.OptimizerRule { @@ -75,36 +74,26 @@ protected LogicalPlan rule(Filter filter) { return filter; } - List scoreAttrs; - LogicalPlan scoringPlan; - if (usesScoring(filter)) { - scoreAttrs = filter.output() - .stream() - .filter(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE)) - .toList(); - // Use the original filter, changing knn to exact queries - scoringPlan = filter.with( - filter.condition().transformDown(Knn.class, ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery) - ); - } else { - // Replace knn with scoring expressions of exact queries - List exactQueries = knnQueries.get() - .stream() - .map(ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery) - .toList(); - assert exactQueries.isEmpty() == false; - - // Create an Eval for scoring the exact queries - List exactScoreAliases = exactQueryScoreAliases(exactQueries); - scoringPlan = new Eval(EMPTY, filter.with(conditionWithoutKnns), exactScoreAliases); - scoreAttrs = exactScoreAliases.stream().map(Alias::toAttribute).toList(); - } + // Replace knn with scoring expressions of exact queries + List exactQueries = knnQueries.get() + .stream() + .map(ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery) + .toList(); + assert exactQueries.isEmpty() == false; + + // Create an Eval for scoring the exact queries + List exactScoreAliases = exactQueryScoreAliases(exactQueries); + LogicalPlan scoringPlan = new Eval(EMPTY, filter.with(conditionWithoutKnns), exactScoreAliases); + List scoreAttrs = 
exactScoreAliases.stream().map(Alias::toAttribute).toList(); // Sort on the scores, limit on the minimum k from the queries TopN topN = createTopN(scoreAttrs, knnQueries.get(), scoringPlan); // Filter on scores > 0. We could filter earlier, but could be combined with the existing filter and _score would not be updated - return createScoreFilter(scoreAttrs, topN); + Filter scoreFilter = createScoreFilter(scoreAttrs, topN); + + // Drop the scores + return new Project(EMPTY, scoreFilter, filter.output()); } private static Expression replaceKnnByExactQuery(Knn knn) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 42c54f6f4dca5..f4ea708e6c366 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -34,7 +34,6 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MapExpression; -import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -8040,7 +8039,10 @@ public void testKnnWithNonPushablePrefiltersNoScoring() { | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 """); - var limit = as(plan, Limit.class); + var project = as(plan, Project.class); + assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); + + var limit = as(project.child(), Limit.class); 
assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); // Next: Filter[$$knn_score$0 > 0.0] @@ -8076,17 +8078,20 @@ public void testKnnWithNonPushablePrefiltersNoScoring() { var esRelation = as(prefilter.child(), EsRelation.class); } - public void testKnnWithNonPushablePrefiltersNoScoringMultipleKnn() { + public void testKnnWithNonPushablePrefiltersScoringMultipleKnn() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); // Disjunctions with pushable conditions are allowed var plan = planTypes(""" - from types + from types metadata _score | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and knn(dense_vector, [0.4, 0.5, 0.6], 7) and match(text, "hello") and length(keyword) > 10 """); - var limit = as(plan, Limit.class); + var project = as(plan, Project.class); + assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); + + var limit = as(project.child(), Limit.class); assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); // Next: Filter[$$knn_score$1 > 0.0 AND $$knn_score$0 > 0.0] @@ -8142,108 +8147,6 @@ and match(text, "hello") and length(keyword) > 10 var esRelation = as(prefilter.child(), EsRelation.class); } - public void testKnnWithNonPushablePrefiltersScoring() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); - - // Disjunctions with pushable conditions are allowed - var plan = planTypes(""" - from types metadata _score - | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 - """); - - // Top-level: Limit[1000[INTEGER],false] - var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); - - // Next: Filter[_score > 0.0] - var filter = as(limit.child(), Filter.class); - var gt = as(filter.condition(), GreaterThan.class); - MetadataAttribute scoreAttr = as(gt.left(), MetadataAttribute.class); - assertThat(scoreAttr.name(), equalTo("_score")); - 
assertThat(gt.right().fold(FoldContext.small()), equalTo(0.0)); - - // Next: TopN[..., 5] - var topN = as(filter.child(), TopN.class); - assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); - assertThat(topN.order().getFirst().child(), equalTo(scoreAttr)); - - // Next: Filter[EXACTNN(...) AND MATCH(...) AND LENGTH(keyword) > 10] - var innerFilter = as(topN.child(), Filter.class); - var and1 = as(innerFilter.condition(), And.class); - var and2 = as(and1.left(), And.class); - - var exactNN = as(and2.left(), ExactNN.class); - var field = as(exactNN.field(), FieldAttribute.class); - assertThat(field.name(), equalTo("dense_vector")); - assertThat(exactNN.query().toString(), equalTo("[0.1, 0.2, 0.3]")); - var match = as(and2.right(), Match.class); - var lenGt = as(and1.right(), GreaterThan.class); - - assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); - assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); - - // Next: EsRelation[types] - var esRelation = as(innerFilter.child(), EsRelation.class); - } - - public void testKnnWithNonPushablePrefiltersScoringMultipleKnns() { - assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); - - // Disjunctions with pushable conditions are allowed - var plan = planTypes(""" - from types metadata _score - | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 - and knn(dense_vector, [0.3, 0.4, 0.5], 7) - """); - - // Top-level: Limit[1000[INTEGER],false] - var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); - - // Next: Filter[_score > 0.0] - var filter = as(limit.child(), Filter.class); - var gt = as(filter.condition(), GreaterThan.class); - MetadataAttribute scoreAttr = as(gt.left(), MetadataAttribute.class); - assertThat(scoreAttr.name(), equalTo("_score")); - assertThat(gt.right().fold(FoldContext.small()), equalTo(0.0)); - - // Next: TopN[..., 5] - var topN = 
as(filter.child(), TopN.class); - assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); - assertThat(topN.order().getFirst().child(), equalTo(scoreAttr)); - - // Next: Filter[EXACTNN(...) AND MATCH(...) AND LENGTH(keyword) > 10 AND EXACTNN(...)] - var innerFilter = as(topN.child(), Filter.class); - // Should be a chain of ANDs - var and1 = as(innerFilter.condition(), And.class); - var and2 = as(and1.left(), And.class); - var and3 = as(and2.left(), And.class); - - // First EXACTNN - var exactNN1 = as(and3.left(), ExactNN.class); - var field1 = as(exactNN1.field(), FieldAttribute.class); - assertThat(field1.name(), equalTo("dense_vector")); - assertThat(exactNN1.query().toString(), equalTo("[0.1, 0.2, 0.3]")); - - // MATCH - var match = as(and3.right(), Match.class); - - // LENGTH(keyword) > 10 - var lenGt = as(and2.right(), GreaterThan.class); - assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); - assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); - - // Second EXACTNN - var exactNN2 = as(and1.right(), ExactNN.class); - var field2 = as(exactNN2.field(), FieldAttribute.class); - assertThat(field2.name(), equalTo("dense_vector")); - assertThat(exactNN2.query().toString(), equalTo("[0.3, 0.4, 0.5]")); - - // Next: EsRelation[types] - var esRelation = as(innerFilter.child(), EsRelation.class); - - } - public void testKnnInDisjunctionsWithNonPushablePrefilters() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); From 521bff54d560cb03fd174042dc67bdafdbedcdc0 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 18:22:59 +0200 Subject: [PATCH 18/24] Spotless --- .../compute/operator/ScoreOperator.java | 1 - .../ReplaceKnnWithNoPushedDownFilters.java | 17 ++++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java index e52c1276d540c..1c3d522fda5ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java @@ -9,7 +9,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.Page; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java index df69634f6977c..437952a48e098 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java @@ -40,7 +40,17 @@ import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; /** - * Replaces KNN queries with non pushable prefilters used in filters + * Replaces KNN queries with non pushable prefilters used in filters. 
+ * + * A query like: + * WHERE knn(field1, [..], 10) AND non-pushable-filter + * + * Will be replaced with: + * | EVAL knn_score = SCORE(exact_nn(field1, [..])) + * | WHERE non-pushable-filter + * | TOPN 10 knn_score DESC + * | WHERE knn_score > 0 + * | DROP knn_score */ public class ReplaceKnnWithNoPushedDownFilters extends OptimizerRules.OptimizerRule { @@ -75,10 +85,7 @@ protected LogicalPlan rule(Filter filter) { } // Replace knn with scoring expressions of exact queries - List exactQueries = knnQueries.get() - .stream() - .map(ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery) - .toList(); + List exactQueries = knnQueries.get().stream().map(ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery).toList(); assert exactQueries.isEmpty() == false; // Create an Eval for scoring the exact queries From 40fa3872870bb4ffac93f1ba0d2e190d3018d855 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 18:32:23 +0200 Subject: [PATCH 19/24] Small docs change --- .../rules/logical/ReplaceKnnWithNoPushedDownFilters.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java index 437952a48e098..f6d16002903bd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java @@ -46,8 +46,8 @@ * WHERE knn(field1, [..], 10) AND non-pushable-filter * * Will be replaced with: - * | EVAL knn_score = SCORE(exact_nn(field1, [..])) * | WHERE non-pushable-filter + * | EVAL knn_score = SCORE(exact_nn(field1, [..])) * | TOPN 10 knn_score DESC * | WHERE knn_score > 0 * | DROP knn_score @@ -96,7 +96,7 @@ 
protected LogicalPlan rule(Filter filter) { // Sort on the scores, limit on the minimum k from the queries TopN topN = createTopN(scoreAttrs, knnQueries.get(), scoringPlan); - // Filter on scores > 0. We could filter earlier, but could be combined with the existing filter and _score would not be updated + // Filter on scores > 0 Filter scoreFilter = createScoreFilter(scoreAttrs, topN); // Drop the scores From b11c46587cc218634355e13ecbf90025d927123b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 22 Jul 2025 19:00:07 +0200 Subject: [PATCH 20/24] Spotless --- .../xpack/esql/optimizer/LogicalPlanOptimizerTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index eaac50da07c91..d35f97c36795b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -182,9 +182,9 @@ import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; -import static org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters.EXACT_SCORE_ATTR_NAME; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; +import static 
org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters.EXACT_SCORE_ATTR_NAME; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.contains; @@ -8125,7 +8125,6 @@ public List output() { assertThat(e.getMessage(), containsString("Output has changed from")); } - public void testKnnWithNonPushablePrefiltersNoScoring() { assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); From e9d3ba14b78fd7a018e6925472f81a29e133d3fa Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 23 Jul 2025 08:59:27 +0200 Subject: [PATCH 21/24] Fix test --- .../LocalPhysicalPlanOptimizerTests.java | 70 +++++++++++++++---- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index f2fc051fb497a..e2f35a9afe98f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -65,8 +65,9 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Score; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; -import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import 
org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -143,6 +144,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters.EXACT_SCORE_ATTR_NAME; import static org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests.randomEstimatedRowSize; import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType; import static org.hamcrest.Matchers.contains; @@ -2022,19 +2024,59 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); - var limit = as(plan, LimitExec.class); - var exchange = as(limit.child(), ExchangeExec.class); - var project = as(exchange.child(), ProjectExec.class); - var field = as(project.child(), FieldExtractExec.class); - var secondLimit = as(field.child(), LimitExec.class); - var filter = as(secondLimit.child(), FilterExec.class); - var and = as(filter.condition(), And.class); - var knn = as(and.left(), Knn.class); - assertEquals("(keyword == \"test\") or length(text) > 10", knn.filterExpressions().get(0).toString()); - assertEquals("integer > 10", knn.filterExpressions().get(1).toString()); + var project = as (plan, ProjectExec.class); + assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); - var fieldExtract = as(filter.child(), FieldExtractExec.class); - var queryExec = as(fieldExtract.child(), EsQueryExec.class); + + // LimitExec + var limit = as(project.child(), LimitExec.class); + assertThat(as(limit.limit(), Literal.class).value(), is(1000)); + + // FilterExec on $$knn_score$0 > 0.0 + var filter = as(limit.child(), FilterExec.class); + 
var gt = as(filter.condition(), GreaterThan.class); + ReferenceAttribute scoreAttr = as(gt.left(), ReferenceAttribute.class); + assertThat(scoreAttr.name(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt.right().fold(FoldContext.small()), is(0.0)); + + // TopNExec on $$knn_score$0 desc + var topN = as(filter.child(), TopNExec.class); + assertThat(as(topN.limit(), Literal.class).value(), is(10)); + assertThat(Expressions.name(topN.order().getFirst().child()), equalTo(scoreAttr.name())); + + // ExchangeExec + var exchange = as(topN.child(), ExchangeExec.class); + + // ProjectExec (with score column) + var project2 = as(exchange.child(), ProjectExec.class); + assertTrue(project2.output().contains(scoreAttr)); + + var fieldExtract = as(project2.child(), FieldExtractExec.class); + + var topN2 = as(fieldExtract.child(), TopNExec.class); + + // EvalExec for score + var eval = as(topN2.child(), EvalExec.class); + var scoreAlias = as(eval.fields().getFirst(), Alias.class); + assertThat(scoreAlias.name(), containsString(EXACT_SCORE_ATTR_NAME)); + var score = as(scoreAlias.child(), Score.class); + var exactNN = as(score.children().getFirst(), ExactNN.class); + var field = as(exactNN.field(), FieldAttribute.class); + assertThat(field.name(), equalTo("dense_vector")); + assertThat(exactNN.query().toString(), equalTo("[0.0, 1.0, 2.0]")); + + // FieldExtractExec for dense_vector + var fieldExtract2 = as(eval.child(), FieldExtractExec.class); + + // FilterExec for OR + var filter2 = as(fieldExtract2.child(), FilterExec.class); + var or = as(filter2.condition(), Or.class); + + // FieldExtractExec for keyword, text + var fieldExtract3 = as(filter2.child(), FieldExtractExec.class); + + // EsQueryExec for integer > 10 + var esQuery = as(fieldExtract3.child(), EsQueryExec.class); // The query should only contain the pushable condition QueryBuilder integerGtQuery = wrapWithSingleQuery( @@ -2044,7 +2086,7 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { new 
Source(2, 47, "integer > 10") ); - assertEquals(integerGtQuery.toString(), queryExec.query().toString()); + assertEquals(integerGtQuery.toString(), esQuery.query().toString()); } public void testPushDownComplexNegationsToKnnPrefilter() { From 8d8b7fbef692894892fe496a3dde44f69db76de4 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 23 Jul 2025 09:14:16 +0200 Subject: [PATCH 22/24] Fix test --- .../src/main/resources/knn-function.csv-spec | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index 0a7cd0492b942..de283dfc7d34f 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -283,16 +283,16 @@ from colors metadata _score ; color:text -olive -purple -indigo -sienna brown -peru +coral +gold maroon -navy -tomato +olive orange +peru +salmon +sienna +tomato ; testKnnInStatsNonPushable From 8f1fb3a7d7f2033317b48a78f6e0f9cf561d5824 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 23 Jul 2025 07:22:07 +0000 Subject: [PATCH 23/24] [CI] Auto commit changes from spotless --- .../xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index e2f35a9afe98f..997fe859632d1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -2024,10 +2024,9 @@ public void 
testNotPushDownKnnWithNonPushablePrefilters() { """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); - var project = as (plan, ProjectExec.class); + var project = as(plan, ProjectExec.class); assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); - // LimitExec var limit = as(project.child(), LimitExec.class); assertThat(as(limit.limit(), Literal.class).value(), is(1000)); From b132338134660df9702b069c42d9422ef209b072 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 23 Jul 2025 10:35:35 +0200 Subject: [PATCH 24/24] Fix test --- .../src/main/resources/knn-function.csv-spec | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index de283dfc7d34f..0be345645401b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -231,7 +231,7 @@ required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], 10) and composed_name == false +| where knn(rgb_vector, [100,100,0], 10) and composed_name == false | sort _score desc, color asc | keep color ; @@ -239,13 +239,13 @@ from colors metadata _score color:text olive sienna -chocolate -peru brown -firebrick -chartreuse green maroon +firebrick +chocolate +peru +gray ; testKnnWithNonPushableDisjunctions