
Add Mistral AI Chat Completion support to Inference Plugin #128538

Merged
Changes from 12 commits

Commits (24)
f7dc246
Add Mistral AI Chat Completion support to Inference Plugin
Jan-Kazlouski-elastic May 27, 2025
0aa8da8
Add changelog file
Jan-Kazlouski-elastic May 27, 2025
c3a8716
Fix tests and typos
Jan-Kazlouski-elastic May 27, 2025
69f16b3
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic May 29, 2025
91f8ccf
Refactor Mistral chat completion integration and add tests
Jan-Kazlouski-elastic May 29, 2025
ff81e36
Refactor Mistral error response handling and extract StreamingErrorRe…
Jan-Kazlouski-elastic May 29, 2025
5a9ce48
Add Mistral chat completion request and response tests
Jan-Kazlouski-elastic May 30, 2025
17dead3
Enhance error response documentation and clarify StreamingErrorRespon…
Jan-Kazlouski-elastic Jun 1, 2025
74b3df6
Refactor Mistral chat completion request handling and introduce skip …
Jan-Kazlouski-elastic Jun 1, 2025
d50bc76
Refactor MistralChatCompletionServiceSettings to include rateLimitSet…
Jan-Kazlouski-elastic Jun 1, 2025
4824f12
Enhance MistralErrorResponse documentation with detailed error examples
Jan-Kazlouski-elastic Jun 1, 2025
158622e
Add comment for Mistral-specific 422 validation error in OpenAiRespon…
Jan-Kazlouski-elastic Jun 1, 2025
60df2f7
Merge remote-tracking branch 'origin/main' into feature/mistral-chat-…
Jan-Kazlouski-elastic Jun 1, 2025
34ca847
[CI] Auto commit changes from spotless
Jun 2, 2025
cc13241
Merge remote-tracking branch 'origin/main' into feature/mistral-chat-…
Jan-Kazlouski-elastic Jun 2, 2025
24c52e8
Refactor OpenAiUnifiedChatCompletionRequestEntity to remove unused fi…
Jan-Kazlouski-elastic Jun 2, 2025
f184fc7
Refactor UnifiedChatCompletionRequestEntity and UnifiedCompletionRequ…
Jan-Kazlouski-elastic Jun 2, 2025
5cc7402
Refactor MistralChatCompletionRequestEntityTests to improve JSON asse…
Jan-Kazlouski-elastic Jun 2, 2025
977bfc4
Add unit tests for MistralUnifiedChatCompletionResponseHandler to val…
Jan-Kazlouski-elastic Jun 4, 2025
f49fac2
Add unit tests for MistralService
Jan-Kazlouski-elastic Jun 4, 2025
7505915
Merge remote-tracking branch 'origin/main' into feature/mistral-chat-…
Jan-Kazlouski-elastic Jun 4, 2025
fb2be46
Merge remote-tracking branch 'origin/main' into feature/mistral-chat-…
Jan-Kazlouski-elastic Jun 4, 2025
68a5432
Update expected service count in testGetServicesWithCompletionTaskType
Jan-Kazlouski-elastic Jun 4, 2025
102da20
Merge remote-tracking branch 'origin/main' into feature/mistral-chat-…
Jan-Kazlouski-elastic Jun 4, 2025
2 changes: 1 addition & 1 deletion docs/changelog/128538.yaml
@@ -1,5 +1,5 @@
pr: 128538
summary: "[ML] Add Mistral Chat Completion support to the Inference Plugin"
summary: "Added Mistral Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
10 changes: 8 additions & 2 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -182,7 +182,9 @@ static TransportVersion def(int id) {
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_38);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);
public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED_8_19 = def(8_841_0_39);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_40);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -268,7 +270,11 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_083_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_083_0_00);
public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED = def(9_084_0_00);
public static final TransportVersion ESQL_LIMIT_ROW_SIZE = def(9_085_0_00);
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY = def(9_086_0_00);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_087_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
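The constant was renumbered because new transport versions (Vertex AI chat completion, the custom service, and the ES|QL changes) landed on main while this PR was open. As a hedged sketch, not code from this diff, such a constant is typically consulted when gating newly serialized fields; the concrete Mistral call sites are not shown here:

// Sketch only: typical gating on a TransportVersion constant during serialization.
@Override
public void writeTo(StreamOutput out) throws IOException {
    if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED)) {
        // write the fields introduced alongside this version
    }
}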
@@ -78,6 +78,11 @@ public record UnifiedCompletionRequest(
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
*/
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
/**
* Some providers don't support the stream_options field.
* This parameter is used to skip the stream_options field in the JSON output.
*/
public static final String SKIP_STREAM_OPTIONS_PARAM = "skip_stream_options";

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
@@ -91,6 +96,23 @@ public static Params withMaxTokens(String modelId, Params params) {
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
* - Key: {@link #SKIP_STREAM_OPTIONS_PARAM}, Value: "true"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(
Map.entry(MODEL_ID_PARAM, modelId),
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(SKIP_STREAM_OPTIONS_PARAM, Boolean.TRUE.toString())
),
params
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
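A hedged usage sketch of the new factory method (imports omitted; the model id below is an assumed value, purely for illustration): a provider that rejects stream_options builds its Params with the new helper, and serialization code reads the flag back via paramAsBoolean.

// Sketch only: "mistral-small-latest" is an assumed model id, not taken from this PR.
ToXContent.Params params = UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(
    "mistral-small-latest",
    ToXContent.EMPTY_PARAMS
);
// A request entity serializing with these params can then omit stream_options:
boolean skipStreamOptions = params.paramAsBoolean(UnifiedCompletionRequest.SKIP_STREAM_OPTIONS_PARAM, false); // true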
@@ -154,15 +154,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"openai",
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker"
"amazon_sagemaker",
"mistral"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

@@ -176,7 +177,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai"
"googlevertexai",
"mistral"
).toArray()
)
);
@@ -21,12 +21,13 @@
* A pattern is emerging in how external providers provide error responses.
*
* At a minimum, these return:
* <pre><code>
* {
*     "error": {
* "message": "(error message)"
* }
* }
*
* </code></pre>
* Others may return additional information such as error codes specific to the service.
*
* This currently covers error handling for Azure AI Studio, however this pattern
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;

import java.util.Objects;
import java.util.Optional;

/**
* Represents an error response from a streaming inference service.
* This class extends {@link ErrorResponse} and provides additional fields
* specific to streaming errors, such as code, param, and type.
* An example error response for a streaming service might look like:
* <pre><code>
* {
* "error": {
* "message": "Invalid input",
* "code": "400",
* "param": "input",
* "type": "invalid_request_error"
* }
* }
* </code></pre>
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
*/
public class StreamingErrorResponse extends ErrorResponse {
Contributor:
Can you add a comment with an example error message that this would parse? Let's also add a TODO to note that ErrorMessageResponseEntity (https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java) is nearly identical (it doesn't parse as many fields) and we should remove the duplication.

Jan-Kazlouski-elastic (Contributor Author), Jun 1, 2025:
Done.

private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> Optional.ofNullable((StreamingErrorResponse) args[0])
);
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
);

static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));

ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}

/**
* Standard error response parser. This can be overridden for those subclasses that
* have a different error response structure.
* @param response The error response as an HttpResult
*/
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}

/**
* Standard error response parser. This can be overridden for those subclasses that
* have a different error response structure.
* @param response The error response as a string
*/
public static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}

@Nullable
private final String code;
@Nullable
private final String param;
private final String type;

StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
super(errorMessage);
this.code = code;
this.param = param;
this.type = Objects.requireNonNull(type);
}

@Nullable
public String code() {
return code;
}

@Nullable
public String param() {
return param;
}

public String type() {
return type;
}
}
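For context, a hedged sketch of how the parser behaves on the payload documented in the class javadoc (not part of this PR's diff; the getter names below are those declared in this class):

// Sketch only: feeding the javadoc example through the string parser.
ErrorResponse error = StreamingErrorResponse.fromString(
    "{\"error\":{\"message\":\"Invalid input\",\"code\":\"400\",\"param\":\"input\",\"type\":\"invalid_request_error\"}}"
);
if (error instanceof StreamingErrorResponse streamingError) {
    String type = streamingError.type();   // "invalid_request_error"
    String code = streamingError.code();   // "400"
    String param = streamingError.param(); // "input"
}
// Any payload that does not match this shape falls back to ErrorResponse.UNDEFINED_ERROR.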
@@ -15,6 +15,8 @@
import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.inference.UnifiedCompletionRequest.SKIP_STREAM_OPTIONS_PARAM;

/**
* Represents a unified chat completion request entity.
* This class is used to convert the unified chat input into a format that can be serialized to XContent.
@@ -46,20 +48,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

builder.field(STREAM_FIELD, stream);
if (stream) {
fillStreamOptionsFields(builder);
// If request is streamed and skip stream options parameter is not true, include stream options in the request.
if (stream == true && params.paramAsBoolean(SKIP_STREAM_OPTIONS_PARAM, false) == false) {
Contributor:
nit: How about we reverse the naming? "skip" plus false reads close to a double negative to me, so maybe:
if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true) == true) {

Jan-Kazlouski-elastic (Contributor Author), Jun 2, 2025:
Good thinking. Defaulting the boolean to true means we don't have to set it for every other provider. I took another look at CONTRIBUTING.md: according to it, we should use an == check on boolean values only when checking for "false", so I replaced it with:
stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)
I also extended the javadoc for INCLUDE_STREAM_OPTIONS_PARAM.

builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();
}

return builder;
}

/**
* This method is used to fill the stream options fields in the request entity.
* It is called when the stream option is set to true.
*/
protected void fillStreamOptionsFields(XContentBuilder builder) throws IOException {
builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();
}
}
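The practical effect, sketched under the assumption that the surrounding fields are unchanged: the flag only controls whether stream_options is emitted alongside stream.

// Sketch of the serialized tail of the request body:
// - stream == true, SKIP_STREAM_OPTIONS_PARAM unset (the default for existing providers):
//     "stream": true,
//     "stream_options": { "include_usage": true }
// - stream == true, SKIP_STREAM_OPTIONS_PARAM set to "true" (the Mistral chat completion path):
//     "stream": true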
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

/**
* Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
* This class is specifically designed to handle Mistral's error response format.
*/
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

/**
* Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
*
* @param requestType The type of request being handled (e.g., "mistral completions").
* @param parseFunction The function to parse the response.
*/
public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, MistralErrorResponse::fromResponse);
}
}
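For context, a hedged sketch of how an action creator might instantiate this handler. The parse function shown is an assumption based on Mistral's OpenAI-compatible completions API; the actual MistralActionCreator wiring is not part of this diff.

// Sketch only: OpenAiChatCompletionResponseEntity::fromResponse is an assumed parser, for illustration.
static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
    "mistral completions",
    OpenAiChatCompletionResponseEntity::fromResponse
);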
@@ -99,15 +99,15 @@ protected void doInfer(
var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());

switch (model) {
case MistralEmbeddingsModel mistralEmbeddingsModel -> {
var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
action.execute(inputs, timeout, listener);
}
case MistralChatCompletionModel mistralChatCompletionModel -> {
var action = mistralChatCompletionModel.accept(actionCreator);
action.execute(inputs, timeout, listener);
}
default -> listener.onFailure(createInvalidModelException(model));
case MistralEmbeddingsModel mistralEmbeddingsModel:
mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
break;
case MistralChatCompletionModel mistralChatCompletionModel:
mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
break;
default:
listener.onFailure(createInvalidModelException(model));
break;
}
}

@@ -292,27 +270,23 @@ private static MistralModel createModel(
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
case TEXT_EMBEDDING -> new MistralEmbeddingsModel(
modelId,
taskType,
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
case CHAT_COMPLETION, COMPLETION -> new MistralChatCompletionModel(
modelId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
switch (taskType) {
case TEXT_EMBEDDING:
return new MistralEmbeddingsModel(
modelId,
taskType,
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
case CHAT_COMPLETION, COMPLETION:
return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
default:
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
}
}

private MistralModel createModelFromPersistent(