-
Notifications
You must be signed in to change notification settings - Fork 25.4k
Add Mistral AI Chat Completion support to Inference Plugin #128538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
f7dc246
0aa8da8
c3a8716
69f16b3
91f8ccf
ff81e36
5a9ce48
17dead3
74b3df6
d50bc76
4824f12
158622e
60df2f7
34ca847
cc13241
24c52e8
f184fc7
5cc7402
977bfc4
f49fac2
7505915
fb2be46
68a5432
102da20
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
pr: 128538 | ||
summary: "[ML] Add Mistral Chat Completion support to the Inference Plugin" | ||
summary: "Added Mistral Chat Completion support to the Inference Plugin" | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.response.streaming; | ||
|
||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.xcontent.ParseField; | ||
import org.elasticsearch.xcontent.XContentFactory; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; | ||
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; | ||
|
||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
/** | ||
* Represents an error response from a streaming inference service. | ||
* This class extends {@link ErrorResponse} and provides additional fields | ||
* specific to streaming errors, such as code, param, and type. | ||
* An example error response for a streaming service might look like: | ||
* <pre><code> | ||
* { | ||
* "error": { | ||
* "message": "Invalid input", | ||
* "code": "400", | ||
* "param": "input", | ||
* "type": "invalid_request_error" | ||
* } | ||
* } | ||
* </code></pre> | ||
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication. | ||
*/ | ||
public class StreamingErrorResponse extends ErrorResponse { | ||
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>( | ||
"streaming_error", | ||
true, | ||
args -> Optional.ofNullable((StreamingErrorResponse) args[0]) | ||
); | ||
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>( | ||
"streaming_error", | ||
true, | ||
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) | ||
); | ||
|
||
static { | ||
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); | ||
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code")); | ||
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param")); | ||
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); | ||
|
||
ERROR_PARSER.declareObjectOrNull( | ||
ConstructingObjectParser.optionalConstructorArg(), | ||
ERROR_BODY_PARSER, | ||
null, | ||
new ParseField("error") | ||
); | ||
} | ||
|
||
/** | ||
* Standard error response parser. This can be overridden for those subclasses that | ||
* have a different error response structure. | ||
* @param response The error response as an HttpResult | ||
*/ | ||
public static ErrorResponse fromResponse(HttpResult response) { | ||
try ( | ||
XContentParser parser = XContentFactory.xContent(XContentType.JSON) | ||
.createParser(XContentParserConfiguration.EMPTY, response.body()) | ||
) { | ||
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); | ||
} catch (Exception e) { | ||
// swallow the error | ||
} | ||
|
||
return ErrorResponse.UNDEFINED_ERROR; | ||
} | ||
|
||
/** | ||
* Standard error response parser. This can be overridden for those subclasses that | ||
* have a different error response structure. | ||
* @param response The error response as a string | ||
*/ | ||
public static ErrorResponse fromString(String response) { | ||
try ( | ||
XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response) | ||
) { | ||
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); | ||
} catch (Exception e) { | ||
// swallow the error | ||
} | ||
|
||
return ErrorResponse.UNDEFINED_ERROR; | ||
} | ||
|
||
@Nullable | ||
private final String code; | ||
@Nullable | ||
private final String param; | ||
private final String type; | ||
|
||
StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { | ||
super(errorMessage); | ||
this.code = code; | ||
this.param = param; | ||
this.type = Objects.requireNonNull(type); | ||
} | ||
|
||
@Nullable | ||
public String code() { | ||
return code; | ||
} | ||
|
||
@Nullable | ||
public String param() { | ||
return param; | ||
} | ||
|
||
public String type() { | ||
return type; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,8 @@ | |
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
import static org.elasticsearch.inference.UnifiedCompletionRequest.SKIP_STREAM_OPTIONS_PARAM; | ||
|
||
/** | ||
* Represents a unified chat completion request entity. | ||
* This class is used to convert the unified chat input into a format that can be serialized to XContent. | ||
|
@@ -46,20 +48,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws | |
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); | ||
|
||
builder.field(STREAM_FIELD, stream); | ||
if (stream) { | ||
fillStreamOptionsFields(builder); | ||
// If request is streamed and skip stream options parameter is not true, include stream options in the request. | ||
if (stream == true && params.paramAsBoolean(SKIP_STREAM_OPTIONS_PARAM, false) == false) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: How about we reverse the naming, skip and false seems closer to a double negative to me so maybe:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good thinking. Defaulting boolean to true allows us not to fill it out for every other provider. Took another look at the CONTRIBUTING.md. According to it we should use == check for boolean values only in case we're checking for "false". So I replaced it with: |
||
builder.startObject(STREAM_OPTIONS_FIELD); | ||
builder.field(INCLUDE_USAGE_FIELD, true); | ||
builder.endObject(); | ||
} | ||
|
||
return builder; | ||
} | ||
|
||
/** | ||
* This method is used to fill the stream options fields in the request entity. | ||
* It is called when the stream option is set to true. | ||
*/ | ||
protected void fillStreamOptionsFields(XContentBuilder builder) throws IOException { | ||
builder.startObject(STREAM_OPTIONS_FIELD); | ||
builder.field(INCLUDE_USAGE_FIELD, true); | ||
builder.endObject(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.services.mistral; | ||
|
||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; | ||
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse; | ||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; | ||
|
||
/** | ||
* Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler. | ||
* This class is specifically designed to handle Mistral's error response format. | ||
*/ | ||
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { | ||
|
||
/** | ||
* Constructs a MistralCompletionResponseHandler with the specified request type and response parser. | ||
* | ||
* @param requestType The type of request being handled (e.g., "mistral completions"). | ||
* @param parseFunction The function to parse the response. | ||
*/ | ||
public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) { | ||
super(requestType, parseFunction, MistralErrorResponse::fromResponse); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment with an example error message that this would parse? Let's also add a TODO to note that
ErrorMessageResponseEntity
https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java is nearly identical (doesn't parse as many fields) and we should remove the duplicationUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.