Skip to content

ESQL - KNN functions with non-pushed down filters #131708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 27 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7f313f4
Add exact NN query infra
carlosdelest Jul 18, 2025
2b1a4fa
First version ReplaceKnnWithNoPushedDownFiltersWithEvalTopN
carlosdelest Jul 21, 2025
363e50e
Refactoring and spotless
carlosdelest Jul 22, 2025
a60aa3d
Add _score use case
carlosdelest Jul 22, 2025
63c62c7
Check knn usage in disjunctions for non pushable filters
carlosdelest Jul 22, 2025
4440717
Rename
carlosdelest Jul 22, 2025
dee1e91
Move ReplaceKnnWithNoPushedDownFilters to logical optimizer
carlosdelest Jul 22, 2025
6438bc1
Add tests
carlosdelest Jul 22, 2025
21bb1f1
Fix knn refs
carlosdelest Jul 22, 2025
376be41
Add tests
carlosdelest Jul 22, 2025
a24645d
capability bump
carlosdelest Jul 22, 2025
87285f1
capability bump
carlosdelest Jul 22, 2025
aa89a4d
Don't expect blocks to be the first ones on score evaluators
carlosdelest Jul 22, 2025
beea012
Move around optimizer rules. Avoid doing optimization for invalid rul…
carlosdelest Jul 22, 2025
bf05ddd
Add CSV tests
carlosdelest Jul 22, 2025
9e673f3
Register and fix ExactNN
carlosdelest Jul 22, 2025
0ff8c23
Added a projection to get rid of extra columns
carlosdelest Jul 22, 2025
521bff5
Spotless
carlosdelest Jul 22, 2025
40fa387
Small docs change
carlosdelest Jul 22, 2025
605f8b0
Merge remote-tracking branch 'origin/main' into non-issue/knn-prefilt…
carlosdelest Jul 22, 2025
b11c465
Spotless
carlosdelest Jul 22, 2025
e9d3ba1
Fix test
carlosdelest Jul 23, 2025
8d8b7fb
Fix test
carlosdelest Jul 23, 2025
5d96906
Merge remote-tracking branch 'origin/main' into non-issue/knn-prefilt…
carlosdelest Jul 23, 2025
8f1fb3a
[CI] Auto commit changes from spotless
Jul 23, 2025
b132338
Fix test
carlosdelest Jul 23, 2025
03d4c21
Merge remote-tracking branch 'carlosdelest/non-issue/knn-prefilter-no…
carlosdelest Jul 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards)
}

public Block executeQuery(Page page) {
// Lucene based operators retrieve DocVectors as first block
Block block = page.getBlock(0);
assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
DocVector docs = (DocVector) block.asVector();
// Search for DocVector block
Block docBlock = null;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made LuceneQueryEvaluator more robust

for (int i = 0; i < page.getBlockCount(); i++) {
if (page.getBlock(i) instanceof DocBlock) {
docBlock = page.getBlock(i);
break;
}
}
assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock";
DocVector docs = (DocVector) docBlock.asVector();
try {
if (docs.singleSegmentNonDecreasing()) {
return evalSingleSegmentNonDecreasing(docs).asBlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
Expand Down Expand Up @@ -46,9 +45,9 @@ public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int sco

@Override
protected Page process(Page page) {
assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount();
assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector();
assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector();
assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an uncovered bug

assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector
: "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector();

Block[] blocks = new Block[page.getBlockCount()];
for (int i = 0; i < page.getBlockCount(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# top-n query at the shard level

knnSearch
required_capability: knn_function_v3
required_capability: knn_function_v4

// tag::knn-function[]
from colors metadata _score
Expand All @@ -30,7 +30,7 @@ chartreuse | [127.0, 255.0, 0.0]
;

knnSearchWithSimilarityOption
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
Expand All @@ -46,7 +46,7 @@ wheat | [245.0, 222.0, 179.0]
;

knnHybridSearch
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
Expand All @@ -68,7 +68,7 @@ yellow | [255.0, 255.0, 0.0]
;

knnWithPrefilter
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green"))
Expand All @@ -82,7 +82,7 @@ green | [0.0, 128.0, 0.0]
;

knnWithNegatedPrefilter
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate"))
Expand All @@ -105,7 +105,7 @@ orange | [255.0, 165.0, 0.0]
;

knnAfterKeep
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| keep rgb_vector, color, _score
Expand All @@ -124,7 +124,7 @@ rgb_vector:dense_vector
;

knnAfterDrop
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| drop primary
Expand All @@ -143,7 +143,7 @@ lime | [0.0, 255.0, 0.0]
;

knnAfterEval
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| eval composed_name = locate(color, " ") > 0
Expand All @@ -162,7 +162,7 @@ golden rod | true
;

knnWithConjunction
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*"
Expand All @@ -181,7 +181,7 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
;

knnWithDisjunctionAndFiltersConjunction
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true
Expand All @@ -204,7 +204,7 @@ yellow | [255.0, 255.0, 0.0]
;

knnWithNegationsAndFiltersConjunction
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue")))
Expand All @@ -227,62 +227,76 @@ azure | [240.0, 255.0, 255.0]
;

knnWithNonPushableConjunction
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| eval composed_name = locate(color, " ") > 0
| where knn(rgb_vector, [128,128,0], 140) and composed_name == false
| where knn(rgb_vector, [128,128,0], 10) and composed_name == false
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can see the change in action - we no longer need to use a large number for k to maintain semantics, nor to use limit at the end.

| sort _score desc, color asc
| keep color, composed_name
| limit 10
| keep color
;

color:text | composed_name:boolean
olive | false
sienna | false
chocolate | false
peru | false
brown | false
firebrick | false
chartreuse | false
gray | false
green | false
maroon | false
color:text
olive
sienna
chocolate
peru
brown
firebrick
chartreuse
green
maroon
;

testKnnWithNonPushableDisjunctions
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors metadata _score
| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
| where knn(rgb_vector, [128,128,0], 10) or length(color) > 10
| sort _score desc, color asc
| keep color
;

color:text
olive
sienna
chocolate
peru
golden rod
brown
firebrick
chartreuse
green
maroon
aqua marine
lemon chiffon
papaya whip
;

testKnnWithNonPushableDisjunctionsOnComplexExpressions
required_capability: knn_function_v3
testKnnWithNonPushableConjunctionsOnComplexExpressions
required_capability: knn_function_v4

from colors metadata _score
| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
| where knn(rgb_vector, [128,128,0], 10) and length(color) < 7 and knn(rgb_vector, [128,0,128], 10) and primary == false
| sort _score desc, color asc
| keep color, primary
| keep color
;

color:text | primary:boolean
olive | false
purple | false
indigo | false
color:text
olive
purple
indigo
sienna
brown
peru
maroon
navy
tomato
orange
;

testKnnInStatsNonPushable
required_capability: knn_function_v3
required_capability: knn_function_v4

from colors
| where length(color) < 10
Expand All @@ -294,7 +308,7 @@ c: long
;

testKnnInStatsWithGrouping
required_capability: knn_function_v3
required_capability: knn_function_v4
required_capability: full_text_functions_in_stats_where

from colors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
Expand All @@ -31,7 +32,9 @@
import static org.elasticsearch.index.IndexMode.LOOKUP;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.equalTo;

@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
public class KnnFunctionIT extends AbstractEsqlIntegTestCase {

private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
Expand Down Expand Up @@ -157,9 +160,71 @@ public void testKnnWithLookupJoin() {
);
}

public void testKnnNotPushedDown() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);

// We retrieve 5 from knn, and 5 from the non-pushed down disjunction. They are disjoint so we get 10 as a result
var query = String.format(Locale.ROOT, """
FROM test
| WHERE knn(vector, %s, 5) OR (length(keyword) > 5 AND length(keyword) <= 10)
| KEEP id, vector, keyword
| SORT id ASC
| LIMIT 20
""", Arrays.toString(queryVector));

try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "vector", "keyword"));
assertColumnTypes(resp.columns(), List.of("integer", "dense_vector", "keyword"));

List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(10, valuesList.size());
}
}

public void testKnnPrefiltersNotPushedDown() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);

// We retrieve 5 from knn, but must be prefiltered with the non-pushed down conjunction
var query = String.format(Locale.ROOT, """
FROM test
| WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10
| SORT id ASC
| LIMIT 20
""", Arrays.toString(queryVector));

try (var resp = run(query)) {
// No added columns
assertThat(resp.columns().size(), equalTo(4));
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(5, valuesList.size());
}
}

public void testKnnPrefiltersNotPushedDownWithScoring() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);

// We retrieve 5 from knn, but must be prefiltered with the non-pushed down conjunction
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10
| SORT id ASC
| LIMIT 20
""", Arrays.toString(queryVector));

try (var resp = run(query)) {
// No added columns
assertThat(resp.columns().size(), equalTo(5));
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(5, valuesList.size());
}
}

@Before
public void setup() throws IOException {
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());

var indexName = "test";
var client = client().admin().indices();
Expand All @@ -176,6 +241,9 @@ public void setup() throws IOException {
.startObject("floats")
.field("type", "float")
.endObject()
.startObject("keyword")
.field("type", "keyword")
.endObject()
.endObject()
.endObject();

Expand All @@ -195,7 +263,8 @@ public void setup() throws IOException {
for (int j = 0; j < numDims; j++) {
vector.add(value++);
}
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector);
docs[i] = prepareIndex("test").setId("" + i)
.setSource("id", String.valueOf(i), "floats", vector, "vector", vector, "keyword", randomAlphaOfLength(i));
indexedVectors.put(i, vector);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,7 @@ public enum Cap {
/**
* Support knn function
*/
KNN_FUNCTION_V3(Build.current().isSnapshot()),
KNN_FUNCTION_V4(Build.current().isSnapshot()),

/**
* Support for the LIKE operator with a list of wildcards.
Expand Down
Loading