diff --git a/docs/changelog/131641.yaml b/docs/changelog/131641.yaml new file mode 100644 index 0000000000000..7d86eed413611 --- /dev/null +++ b/docs/changelog/131641.yaml @@ -0,0 +1,5 @@ +pr: 131641 +summary: Add exception for perform embedding inference requests with query provided +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 64957328d48dd..c23996a3ce87a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -273,6 +273,14 @@ public ActionRequestValidationException validate() { } } + if (taskType.equals(TaskType.TEXT_EMBEDDING) || taskType.equals(TaskType.SPARSE_EMBEDDING)) { + if (query != null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Field [query] cannot be specified for task type [%s]", taskType)); + return e; + } + } + if (taskType.equals(TaskType.TEXT_EMBEDDING) == false && taskType.equals(TaskType.ANY) == false && (inputType != null && InputType.isInternalTypeOrUnspecified(inputType) == false)) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 2e2b9bf9b0d23..696b45117497a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -191,6 +191,24 @@ public void testValidation_TextEmbedding_WithTopN() { assertThat(inputError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];")); } + public void testValidation_TextEmbedding_WithQuery() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + "query", + null, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [query] cannot be specified for task type [text_embedding];")); + } + public void testValidation_Rerank_Null() { InferenceAction.Request queryNullRequest = new InferenceAction.Request( TaskType.RERANK, @@ -249,7 +267,7 @@ public void testValidation_SparseEmbedding_WithInputType() { InferenceAction.Request queryRequest = new InferenceAction.Request( TaskType.SPARSE_EMBEDDING, "model", - "", + null, null, null, List.of("input"), @@ -309,6 +327,27 @@ public void testValidation_SparseEmbedding_WithTopN() { ); } + public void testValidation_SparseEmbedding_WithQuery() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + "model", + "query", + null, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [query] cannot be specified for task type [sparse_embedding];") + ); + } + public void testValidation_Completion_WithInputType() { InferenceAction.Request queryRequest = new InferenceAction.Request( TaskType.COMPLETION,