Skip to content

[Inference Timeout] Supply inference context to all third party services #131251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() {
createWithEmptySettings(threadPool),
ElasticInferenceServiceSettingsTests.create(gatewayUrl),
modelRegistry,
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool)
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool),
mockClusterServiceEmpty()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ public Collection<?> createComponents(PluginServices services) {
serviceComponents.get(),
inferenceServiceSettings,
modelRegistry.get(),
authorizationHandler
authorizationHandler,
context
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
Expand All @@ -321,7 +322,8 @@ public Collection<?> createComponents(PluginServices services) {
),
sageMakerSchemas,
services.threadPool(),
sageMakerConfigurations::getOrCompute
sageMakerConfigurations::getOrCompute,
context
)
)
);
Expand Down Expand Up @@ -383,24 +385,24 @@ public void loadExtensions(ExtensionLoader loader) {

public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new MistralService(httpFactory.get(), serviceComponents.get(), context),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
ElasticsearchInternalService::new,
context -> new CustomService(httpFactory.get(), serviceComponents.get())
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
Expand Down Expand Up @@ -42,11 +43,13 @@ public abstract class SenderService implements InferenceService {
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
private final Sender sender;
private final ServiceComponents serviceComponents;
private final ClusterService clusterService;

public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
Objects.requireNonNull(factory);
sender = factory.createSender();
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.clusterService = Objects.requireNonNull(clusterService);
}

public Sender getSender() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -85,8 +87,20 @@ public class AlibabaCloudSearchService extends SenderService {
InputType.INTERNAL_SEARCH
);

public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AlibabaCloudSearchService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public AlibabaCloudSearchService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
ClusterService clusterService
) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -93,9 +95,19 @@ public class AmazonBedrockService extends SenderService {
public AmazonBedrockService(
HttpRequestSender.Factory httpSenderFactory,
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
ServiceComponents serviceComponents
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(httpSenderFactory, serviceComponents);
this(httpSenderFactory, amazonBedrockFactory, serviceComponents, context.clusterService());
}

public AmazonBedrockService(
HttpRequestSender.Factory httpSenderFactory,
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
ServiceComponents serviceComponents,
ClusterService clusterService
) {
super(httpSenderFactory, serviceComponents, clusterService);
this.amazonBedrockSender = amazonBedrockFactory.createSender();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -58,8 +60,16 @@ public class AnthropicService extends SenderService {

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION);

public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AnthropicService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -84,8 +86,16 @@ public class AzureAiStudioService extends SenderService {
InputType.INTERNAL_SEARCH
);

public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AzureAiStudioService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -69,8 +71,16 @@ public class AzureOpenAiService extends SenderService {
private static final String SERVICE_NAME = "Azure OpenAI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AzureOpenAiService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -84,8 +86,16 @@ public class CohereService extends SenderService {
// The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
// on every request

public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public CohereService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -74,8 +76,16 @@ public class CustomService extends SenderService {
TaskType.COMPLETION
);

public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public CustomService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -58,8 +60,16 @@ public class DeepSeekService extends SenderService {
);
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public DeepSeekService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
this(factory, serviceComponents, context.clusterService());
}

public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Loading