Skip to content

[Inference Timeout] Supply inference context to all third party services #131251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() {
createWithEmptySettings(threadPool),
ElasticInferenceServiceSettingsTests.create(gatewayUrl),
modelRegistry,
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool)
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool),
mockClusterServiceEmpty()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ public Collection<?> createComponents(PluginServices services) {
serviceComponents.get(),
inferenceServiceSettings,
modelRegistry.get(),
authorizationHandler
authorizationHandler,
context
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
Expand All @@ -321,7 +322,8 @@ public Collection<?> createComponents(PluginServices services) {
),
sageMakerSchemas,
services.threadPool(),
sageMakerConfigurations::getOrCompute
sageMakerConfigurations::getOrCompute,
context
)
)
);
Expand Down Expand Up @@ -383,24 +385,24 @@ public void loadExtensions(ExtensionLoader loader) {

public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new MistralService(httpFactory.get(), serviceComponents.get(), context),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
ElasticsearchInternalService::new,
context -> new CustomService(httpFactory.get(), serviceComponents.get())
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
Expand All @@ -17,6 +18,7 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand All @@ -42,11 +44,26 @@ public abstract class SenderService implements InferenceService {
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
private final Sender sender;
private final ServiceComponents serviceComponents;
private final ClusterService clusterService;

public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
public SenderService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
Objects.requireNonNull(factory);
sender = factory.createSender();
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.clusterService = Objects.requireNonNull(context.clusterService());

}

// for testing
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
Objects.requireNonNull(factory);
sender = factory.createSender();
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.clusterService = clusterService;
}

public Sender getSender() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -85,8 +87,21 @@ public class AlibabaCloudSearchService extends SenderService {
InputType.INTERNAL_SEARCH
);

public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AlibabaCloudSearchService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

// for testing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: I don't think we necessarily need this comment in all classes inheriting this constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added these comments to stay consistent with the existing approach in classes like BaseElasticsearchInternalService and ElasticsearchInternalService

Just to clarify, are you suggesting we remove these comments entirely from all the classes, or only from constructors that are simply forwarding to super without additional logic?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to keep them in, that's OK - it was a nitpick, but I thought they were potentially unnecessary. Non blocking comment though. 🙂

public AlibabaCloudSearchService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
ClusterService clusterService
) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -93,9 +95,21 @@ public class AmazonBedrockService extends SenderService {
public AmazonBedrockService(
HttpRequestSender.Factory httpSenderFactory,
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
ServiceComponents serviceComponents
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(httpSenderFactory, serviceComponents);
super(httpSenderFactory, serviceComponents, context);
this.amazonBedrockSender = amazonBedrockFactory.createSender();
}

// for testing
public AmazonBedrockService(
HttpRequestSender.Factory httpSenderFactory,
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
ServiceComponents serviceComponents,
ClusterService clusterService
) {
super(httpSenderFactory, serviceComponents, clusterService);
this.amazonBedrockSender = amazonBedrockFactory.createSender();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -58,8 +60,17 @@ public class AnthropicService extends SenderService {

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION);

public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AnthropicService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

// for testing
public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -83,8 +85,17 @@ public class AzureAiStudioService extends SenderService {
InputType.INTERNAL_SEARCH
);

public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AzureAiStudioService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

// for testing
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -69,8 +71,17 @@ public class AzureOpenAiService extends SenderService {
private static final String SERVICE_NAME = "Azure OpenAI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public AzureOpenAiService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

// for testing
public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -84,8 +86,17 @@ public class CohereService extends SenderService {
// The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
// on every request

public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public CohereService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

// For testing
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -74,8 +76,17 @@ public class CustomService extends SenderService {
TaskType.COMPLETION
);

public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
public CustomService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

// for testing
public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}

@Override
Expand Down
Loading