From 29c74743d52c9e1b22b7c376ba49b7a4058bde03 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Mon, 14 Jul 2025 13:54:35 -0400 Subject: [PATCH 1/9] Refactoring inference services to accept context --- .../xpack/inference/InferencePlugin.java | 40 ++++++++++--------- .../inference/services/SenderService.java | 15 ++++++- .../AlibabaCloudSearchService.java | 11 ++++- .../amazonbedrock/AmazonBedrockService.java | 18 ++++++++- .../services/anthropic/AnthropicService.java | 11 ++++- .../azureaistudio/AzureAiStudioService.java | 11 ++++- .../azureopenai/AzureOpenAiService.java | 11 ++++- .../services/cohere/CohereService.java | 11 ++++- .../services/custom/CustomService.java | 11 ++++- .../services/deepseek/DeepSeekService.java | 11 ++++- .../elastic/ElasticInferenceService.java | 32 ++++++++++++++- .../googleaistudio/GoogleAiStudioService.java | 11 ++++- .../googlevertexai/GoogleVertexAiService.java | 11 ++++- .../huggingface/HuggingFaceBaseService.java | 11 ++++- .../huggingface/HuggingFaceService.java | 11 ++++- .../elser/HuggingFaceElserService.java | 11 ++++- .../ibmwatsonx/IbmWatsonxService.java | 11 ++++- .../services/jinaai/JinaAIService.java | 11 ++++- .../services/mistral/MistralService.java | 11 ++++- .../services/openai/OpenAiService.java | 11 ++++- .../services/sagemaker/SageMakerService.java | 31 +++++++++++++- .../services/voyageai/VoyageAIService.java | 11 ++++- 22 files changed, 264 insertions(+), 59 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index de31f9d6cefc8..b729857c91f81 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -311,7 +311,8 @@ public Collection createComponents(PluginServices services) { serviceComponents.get(), inferenceServiceSettings, modelRegistry.get(), - authorizationHandler + authorizationHandler, + context ), context -> new SageMakerService( new SageMakerModelBuilder(sageMakerSchemas), @@ -321,7 +322,8 @@ public Collection createComponents(PluginServices services) { ), sageMakerSchemas, services.threadPool(), - sageMakerConfigurations::getOrCompute + sageMakerConfigurations::getOrCompute, + context ) ) ); @@ -383,24 +385,24 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), - context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), - context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new CohereService(httpFactory.get(), serviceComponents.get()), - context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), - context -> new MistralService(httpFactory.get(), serviceComponents.get()), - context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), - context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), - context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), - context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), - context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), - context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), - context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), + context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context), + context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context), + context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new CohereService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new MistralService(httpFactory.get(), serviceComponents.get(), context), + context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context), + context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context), + context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context), + context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context), + context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, - context -> new CustomService(httpFactory.get(), serviceComponents.get()) + context -> new CustomService(httpFactory.get(), serviceComponents.get(), context) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ff8ae6fd5aac3..2ebf2c4c00fab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -17,6 +18,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -42,11 +44,22 @@ public abstract class SenderService implements InferenceService { protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); private final Sender sender; private final ServiceComponents serviceComponents; + private final ClusterService clusterService; - public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = Objects.requireNonNull(context.clusterService()); + + } + + // for testing + public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + Objects.requireNonNull(factory); + sender = factory.createSender(); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = clusterService; } public Sender getSender() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 7897317736c72..94598575ec63c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -85,8 +87,13 @@ public class AlibabaCloudSearchService extends SenderService { InputType.INTERNAL_SEARCH ); - public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 591607953ea1a..dcb6a715c1942 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -93,9 +95,21 @@ public class AmazonBedrockService extends SenderService { public AmazonBedrockService( HttpRequestSender.Factory httpSenderFactory, AmazonBedrockRequestSender.Factory amazonBedrockFactory, - ServiceComponents serviceComponents + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(httpSenderFactory, serviceComponents); + super(httpSenderFactory, serviceComponents, context); + this.amazonBedrockSender = amazonBedrockFactory.createSender(); + } + + // for testing + public AmazonBedrockService( + HttpRequestSender.Factory httpSenderFactory, + AmazonBedrockRequestSender.Factory amazonBedrockFactory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { + super(httpSenderFactory, serviceComponents, clusterService); this.amazonBedrockSender = amazonBedrockFactory.createSender(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 791518ccc9168..9ad3c639d6416 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,13 @@ public class AnthropicService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); - public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 04883f23b947f..564ca98e43ac0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -83,8 +85,13 @@ public class AzureAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9ff97c1ba725..0d2718ffa6197 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -69,8 +71,13 @@ public class AzureOpenAiService extends SenderService { private static final String SERVICE_NAME = "Azure OpenAI"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); - public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c2f1221763165..be7b263e02c2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,13 @@ public class CohereService extends SenderService { // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated // on every request - public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // For testing + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4e81d37ead3ad..12b702cc357bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -74,8 +76,13 @@ public class CustomService extends SenderService { TaskType.COMPLETION ); - public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..859b3ca702d36 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -10,12 +10,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,13 @@ public class DeepSeekService extends SenderService { ); private static final EnumSet SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 36712ed922e95..3bac284f4f3cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -22,6 +23,7 @@ import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -139,9 +141,35 @@ public ElasticInferenceService( ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents); + super(factory, serviceComponents, context); + this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( + elasticInferenceServiceSettings.getElasticInferenceServiceUrl() + ); + authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( + serviceComponents, + modelRegistry, + authorizationRequestHandler, + initDefaultEndpoints(elasticInferenceServiceComponents), + IMPLEMENTED_TASK_TYPES, + this, + getSender(), + elasticInferenceServiceSettings + ); + } + + // for testing + public ElasticInferenceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + ClusterService clusterService + ) { + super(factory, serviceComponents, clusterService); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..0541ba39a9c2a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -82,8 +84,13 @@ public class GoogleAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..6de333a176fe1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -97,8 +99,13 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } - public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index b0d40b41914d5..2cf5ee6645b43 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.inference.services.huggingface; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -44,8 +46,13 @@ public abstract class HuggingFaceBaseService extends SenderService { */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..dd9078053f4df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -11,10 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -71,8 +73,13 @@ public class HuggingFaceService extends HuggingFaceBaseService { OpenAiChatCompletionResponseEntity::fromResponse ); - public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..721f629cab34e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -57,8 +59,13 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face ELSER"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); - public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // For testing + public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 9bc63be1f9e7e..d714feb29eb7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -83,8 +85,13 @@ public class IbmWatsonxService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..e062b28c0618b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -76,8 +78,13 @@ public class JinaAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..8269cf3c9c097 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,13 @@ public class MistralService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..12f8313661b16 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -91,8 +93,13 @@ public class OpenAiService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index aafd6c46857fc..27de36fa223d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -37,6 +39,7 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.core.Strings.format; @@ -55,13 +58,15 @@ public class SageMakerService implements InferenceService { private final SageMakerSchemas schemas; private final ThreadPool threadPool; private final LazyInitializable configuration; + private final ClusterService clusterService; public SageMakerService( SageMakerModelBuilder modelBuilder, SageMakerClient client, SageMakerSchemas schemas, ThreadPool threadPool, - CheckedSupplier, RuntimeException> configurationMap + CheckedSupplier, RuntimeException> configurationMap, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { this.modelBuilder = modelBuilder; this.client = client; @@ -74,6 +79,30 @@ public SageMakerService( .setConfigurations(configurationMap.get()) .build() ); + this.clusterService = Objects.requireNonNull(context.clusterService()); + } + + // for testing + public SageMakerService( + SageMakerModelBuilder modelBuilder, + SageMakerClient client, + SageMakerSchemas schemas, + ThreadPool threadPool, + CheckedSupplier, RuntimeException> configurationMap, + ClusterService clusterService + ) { + this.modelBuilder = modelBuilder; + this.client = client; + this.schemas = schemas; + this.threadPool = threadPool; + this.configuration = new LazyInitializable<>( + () -> new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(DISPLAY_NAME) + .setTaskTypes(supportedTaskTypes()) + .setConfigurations(configurationMap.get()) + .build() + ); + this.clusterService = clusterService; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..9b5563570b436 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,8 +98,13 @@ public class VoyageAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + super(factory, serviceComponents, context); + } + + // for testing + public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override From 9ccf7490be66104d700b44c61bdd307bf36aa018 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Mon, 14 Jul 2025 18:01:44 -0400 Subject: [PATCH 2/9] fix linting issues --- .../inference/services/SenderService.java | 6 +- .../AlibabaCloudSearchService.java | 12 ++- .../services/anthropic/AnthropicService.java | 6 +- .../azureaistudio/AzureAiStudioService.java | 6 +- .../azureopenai/AzureOpenAiService.java | 6 +- .../services/cohere/CohereService.java | 6 +- .../services/custom/CustomService.java | 6 +- .../services/deepseek/DeepSeekService.java | 6 +- .../googleaistudio/GoogleAiStudioService.java | 6 +- .../googlevertexai/GoogleVertexAiService.java | 6 +- .../huggingface/HuggingFaceBaseService.java | 6 +- .../huggingface/HuggingFaceService.java | 6 +- .../elser/HuggingFaceElserService.java | 6 +- .../ibmwatsonx/IbmWatsonxService.java | 6 +- .../services/jinaai/JinaAIService.java | 6 +- .../services/mistral/MistralService.java | 6 +- .../services/openai/OpenAiService.java | 6 +- .../services/voyageai/VoyageAIService.java | 6 +- .../services/SenderServiceTests.java | 10 ++- .../AlibabaCloudSearchServiceTests.java | 76 ++++++++++++++---- .../AmazonBedrockServiceTests.java | 77 ++++++++++++++++--- .../anthropic/AnthropicServiceTests.java | 10 +-- .../AzureAiStudioServiceTests.java | 28 ++++--- .../azureopenai/AzureOpenAiServiceTests.java | 22 +++--- .../services/cohere/CohereServiceTests.java | 28 +++---- .../services/custom/CustomServiceTests.java | 3 +- .../deepseek/DeepSeekServiceTests.java | 3 +- .../elastic/ElasticInferenceServiceTests.java | 9 ++- .../GoogleAiStudioServiceTests.java | 24 +++--- .../GoogleVertexAiServiceTests.java | 2 +- .../HuggingFaceBaseServiceTests.java | 3 +- .../HuggingFaceElserServiceTests.java | 5 +- .../huggingface/HuggingFaceServiceTests.java | 32 ++++---- .../ibmwatsonx/IbmWatsonxServiceTests.java | 8 +- .../services/jinaai/JinaAIServiceTests.java | 32 ++++---- .../services/mistral/MistralServiceTests.java | 24 +++--- .../services/openai/OpenAiServiceTests.java | 26 +++---- .../sagemaker/SageMakerServiceTests.java | 3 +- .../voyageai/VoyageAIServiceTests.java | 32 ++++---- 39 files changed, 389 insertions(+), 182 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 2ebf2c4c00fab..7b7dbb7775623 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -46,7 +46,11 @@ public abstract class SenderService implements InferenceService { private final ServiceComponents serviceComponents; private final ClusterService clusterService; - public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public SenderService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 94598575ec63c..ed08c979c0798 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -87,12 +87,20 @@ public class AlibabaCloudSearchService extends SenderService { InputType.INTERNAL_SEARCH ); - public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } // for testing - public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 9ad3c639d6416..5215916275bb3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -60,7 +60,11 @@ public class AnthropicService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); - public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public AnthropicService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 564ca98e43ac0..979a0408521bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -85,7 +85,11 @@ public class AzureAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public AzureAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 0d2718ffa6197..a0e1675c30938 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -71,7 +71,11 @@ public class AzureOpenAiService extends SenderService { private static final String SERVICE_NAME = "Azure OpenAI"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); - public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public AzureOpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index be7b263e02c2f..4d9138304a5a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -86,7 +86,11 @@ public class CohereService extends SenderService { // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated // on every request - public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public CohereService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 12b702cc357bf..294b3775b6af7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -76,7 +76,11 @@ public class CustomService extends SenderService { TaskType.COMPLETION ); - public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public CustomService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 859b3ca702d36..c18ccbe7a2b13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -60,7 +60,11 @@ public class DeepSeekService extends SenderService { ); private static final EnumSet SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public DeepSeekService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 0541ba39a9c2a..ad0ea5d878b80 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -84,7 +84,11 @@ public class GoogleAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public GoogleAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 6de333a176fe1..506c65f3191f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -99,7 +99,11 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } - public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public GoogleVertexAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 2cf5ee6645b43..8d01fcab861fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -46,7 +46,11 @@ public abstract class HuggingFaceBaseService extends SenderService { */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public HuggingFaceBaseService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index dd9078053f4df..63d97f9b1ef52 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -73,7 +73,11 @@ public class HuggingFaceService extends HuggingFaceBaseService { OpenAiChatCompletionResponseEntity::fromResponse ); - public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public HuggingFaceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 721f629cab34e..ed08a53e0aace 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -59,7 +59,11 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face ELSER"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); - public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public HuggingFaceElserService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index d714feb29eb7a..3527d14b9b9a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -85,7 +85,11 @@ public class IbmWatsonxService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public IbmWatsonxService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index e062b28c0618b..e1025fa874c45 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -78,7 +78,11 @@ public class JinaAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public JinaAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 8269cf3c9c097..790afb766b024 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -86,7 +86,11 @@ public class MistralService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public MistralService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 12f8313661b16..9ee494168bf7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -93,7 +93,11 @@ public class OpenAiService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public OpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 9b5563570b436..f26974e525b43 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -98,7 +98,11 @@ public class VoyageAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context) { + public VoyageAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { super(factory, serviceComponents, context); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..7457859a64603 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; @@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -64,7 +66,7 @@ public void testStart_InitializesTheSender() throws IOException { var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); @@ -84,7 +86,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); @@ -102,8 +104,8 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep } private static final class TestSenderService extends SenderService { - TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 8fbbd33d569e4..f0258e9f66ed5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -91,7 +91,13 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -116,7 +122,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -143,7 +155,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -169,7 +187,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -190,7 +214,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -210,7 +240,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -235,7 +271,13 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -262,7 +304,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -279,7 +321,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO public void testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -316,7 +358,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -360,7 +402,7 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -404,7 +446,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -452,7 +494,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -482,7 +524,13 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { String content = XContentHelper.stripWhitespace( """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index a014f27e7f0cc..c3b1cab4b4e0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -959,7 +959,14 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc ); var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1007,7 +1014,12 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); @@ -1042,7 +1054,14 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var results = new TextEmbeddingFloatResults( List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) @@ -1088,7 +1107,14 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); requestSender.enqueue(mockResults); @@ -1132,7 +1158,14 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = AmazonBedrockChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1166,7 +1199,14 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var embeddingSize = randomNonNegativeInt(); var provider = randomFrom(AmazonBedrockProvider.values()); var model = AmazonBedrockEmbeddingsModelTests.createModel( @@ -1205,7 +1245,12 @@ public void testInfer_UnauthorizedResponse() throws IOException { ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { requestSender.enqueue( @@ -1240,7 +1285,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { } public void testSupportsStreaming() throws IOException { - try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { + try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1284,7 +1329,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { var mockResults1 = new TextEmbeddingFloatResults( @@ -1345,7 +1397,12 @@ private AmazonBedrockService createAmazonBedrockService() { ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), mockClusterServiceEmpty() ); - return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + return new AmazonBedrockService( + mock(HttpRequestSender.Factory.class), + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index a3f0b01901009..9111866d29c88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -453,7 +453,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -486,7 +486,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", @@ -579,7 +579,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AnthropicChatCompletionModelTests.createChatCompletionModel( getUrl(webServer), "secret", @@ -679,13 +679,13 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AnthropicService createServiceWithMockSender() { - return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 3d7ba7f7436fb..a14e805a0b5f8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -844,7 +844,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -869,7 +869,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -895,7 +895,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -923,7 +923,7 @@ public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginal private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -956,7 +956,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -994,7 +994,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1064,7 +1064,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1150,7 +1150,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep public void testInfer_WithChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); var model = AzureAiStudioChatCompletionModelTests.createModel( @@ -1187,7 +1187,7 @@ public void testInfer_WithChatCompletionModel() throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1264,7 +1264,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( "id", getUrl(webServer), @@ -1396,7 +1396,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1405,7 +1405,11 @@ public void testSupportsStreaming() throws IOException { // ---------------------------------------------------------------- private AzureAiStudioService createService() { - return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index de2e9ae9a21b8..f3d65c5589169 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -752,7 +752,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -785,7 +785,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep public void testInfer_SendsRequest() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -844,7 +844,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createModelWithRandomValues(); assertThrows( ElasticsearchStatusException.class, @@ -864,7 +864,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureOpenAiEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -891,7 +891,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -952,7 +952,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createCompletionModel( "resource", "deployment", @@ -1209,14 +1209,18 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AzureOpenAiService createAzureOpenAiService() { - return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureOpenAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 52e4f904a4de0..8f189baa33b20 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -779,7 +779,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new CohereService(factory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -812,7 +812,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -886,7 +886,7 @@ public void testInfer_SendsRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -906,7 +906,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(CohereEmbeddingType.values()); var model = CohereEmbeddingsModelTests.createModel( @@ -933,7 +933,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -975,7 +975,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsAreEmpty() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1051,7 +1051,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1125,7 +1125,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v1API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1200,7 +1200,7 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v2API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1297,7 +1297,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1387,7 +1387,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1507,7 +1507,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1591,7 +1591,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { + try (var service = new CohereService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1632,7 +1632,7 @@ private Map getRequestConfigMap(Map serviceSetti } private CohereService createCohereService() { - return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index cc1bb4471c0a9..a707030a34189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -53,6 +53,7 @@ import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -148,7 +149,7 @@ private static void assertCompletionModel(Model model) { public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + return new CustomService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static Map createServiceSettingsMap(TaskType taskType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index af38ee38e1eff..908451b8e681f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -360,7 +360,8 @@ public void testDoChunkedInferAlwaysFails() throws IOException { private DeepSeekService createService() { return new DeepSeekService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6ce484954d3ce..94d1e064648ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1427,7 +1427,8 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1456,7 +1457,8 @@ private ElasticInferenceService createService( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1469,7 +1471,8 @@ private ElasticInferenceService createServiceWithAuthHandler( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), + mockClusterServiceEmpty() ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 41175581df1cf..435ea9de5911b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -658,7 +658,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -696,7 +696,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD var model = GoogleAiStudioEmbeddingsModelTests.createModel("model", getUrl(webServer), "secret"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -730,7 +730,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "candidates": [ @@ -818,7 +818,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -897,7 +897,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -998,7 +998,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed public void testInfer_ResourceNotFound() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1033,7 +1033,7 @@ public void testInfer_ResourceNotFound() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1052,7 +1052,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = GoogleAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1124,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1171,6 +1171,10 @@ private Map getRequestConfigMap( } private GoogleAiStudioService createGoogleAiStudioService() { - return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new GoogleAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 99a09b983787d..26fd076e72462 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -1043,7 +1043,7 @@ public void testGetConfiguration() throws Exception { private GoogleVertexAiService createGoogleVertexAiService() { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool)); + return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 3be4b72c1237f..2cdf3f5263751 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -29,6 +29,7 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; import static org.mockito.Mockito.mock; @@ -92,7 +93,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep private static final class TestService extends HuggingFaceService { TestService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 814d533129439..93156d4331263 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -81,7 +81,7 @@ public void shutdown() throws IOException { public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -137,7 +137,8 @@ public void testGetConfiguration() throws Exception { try ( var service = new HuggingFaceElserService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ) ) { String content = XContentHelper.stripWhitespace(""" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index e2850910ac64a..c770672c5d5f2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -258,7 +258,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -328,7 +328,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -357,7 +357,7 @@ public void testUnifiedCompletionNonStreamingError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -486,7 +486,7 @@ public void testUnifiedCompletionMalformedError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -548,7 +548,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -621,7 +621,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1009,7 +1009,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1060,7 +1060,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1087,7 +1087,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { public void testInfer_SendsElserRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1139,7 +1139,7 @@ public void testInfer_SendsElserRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1158,7 +1158,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = HuggingFaceEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1179,7 +1179,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1233,7 +1233,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th public void testChunkedInfer() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1340,7 +1340,11 @@ public void testGetConfiguration() throws Exception { } private HuggingFaceService createHuggingFaceService() { - return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new HuggingFaceService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 3295ecfd4ece5..ddc62b5a412b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -597,7 +597,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -635,7 +635,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = IbmWatsonxEmbeddingsModelTests.createModel(modelId, projectId, URI.create(url), apiVersion, apiKey, getUrl(webServer)); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1018,12 +1018,12 @@ private Map getRequestConfigMap( } private IbmWatsonxService createIbmWatsonxService() { - return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService { IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index eca76bc1a702a..d36c574e0aa99 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -778,7 +778,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -819,7 +819,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var model = JinaAIEmbeddingsModelTests.createModel( @@ -846,7 +846,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -889,7 +889,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -923,7 +923,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -994,7 +994,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_clustering() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, @@ -1120,7 +1120,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1210,7 +1210,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1295,7 +1295,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1392,7 +1392,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1475,7 +1475,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1540,7 +1540,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1637,7 +1637,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 input String responseJson = """ @@ -1800,7 +1800,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1841,7 +1841,7 @@ private Map getRequestConfigMap(Map serviceSetti } private JinaAIService createJinaAIService() { - return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 4ba9b8aa24394..8e170b25393e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -249,7 +249,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -308,7 +308,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -353,7 +353,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -421,7 +421,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -459,7 +459,7 @@ public void testInfer_StreamRequest_ErrorResponse() { } public void testSupportsStreaming() throws IOException { - try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) { + try (var service = new MistralService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -942,7 +942,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = new Model(ModelConfigurationsTests.createRandomInstance()); assertThrows( @@ -962,7 +962,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = MistralEmbeddingModelTests.createModel( randomAlphaOfLength(10), @@ -990,7 +990,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1028,7 +1028,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1086,7 +1086,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1173,7 +1173,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1276,7 +1276,7 @@ public void testGetConfiguration() throws Exception { // ---------------------------------------------------------------- private MistralService createService() { - return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index c19eb664e88ac..83455861198d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -847,7 +847,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -885,7 +885,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -924,7 +924,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -965,7 +965,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1003,7 +1003,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1099,7 +1099,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1132,7 +1132,7 @@ public void testUnifiedCompletionError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -1189,7 +1189,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1267,7 +1267,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1344,7 +1344,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1400,7 +1400,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1485,7 +1485,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // response with 2 embeddings String responseJson = """ @@ -1656,6 +1656,6 @@ public void testGetConfiguration() throws Exception { } private OpenAiService createOpenAiService() { - return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index d7d9473f18084..bf883a6345398 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -47,6 +47,7 @@ import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS; import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -84,7 +85,7 @@ public void init() { ThreadPool threadPool = mock(); when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of); + sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, mockClusterServiceEmpty()); } public void testSupportedTaskTypes() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 8602621e9eb78..72a3b530ab647 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -718,7 +718,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -763,7 +763,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept "voyage-3-large" ); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -806,7 +806,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = VoyageAIEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -831,7 +831,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -873,7 +873,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -907,7 +907,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -989,7 +989,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1071,7 +1071,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1163,7 +1163,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1251,7 +1251,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1345,7 +1345,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1423,7 +1423,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1490,7 +1490,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1599,7 +1599,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 input String responseJson = """ @@ -1745,7 +1745,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1786,7 +1786,7 @@ private Map getRequestConfigMap(Map serviceSetti } private VoyageAIService createVoyageAIService() { - return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } From 3a5ff0588ea3c01b99fbd5f06844bba2e5db9ae8 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Wed, 16 Jul 2025 08:28:18 -0400 Subject: [PATCH 3/9] adding mock cluster service to fix IT test --- .../integration/InferenceRevokeDefaultEndpointsIT.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 8e40bba8b32f7..1eb530ac1bb9e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() { createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(gatewayUrl), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), + mockClusterServiceEmpty() ); } } From c1c8f4fa1fa93a2e842606526d7b74b094a3081e Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Wed, 16 Jul 2025 11:53:38 -0400 Subject: [PATCH 4/9] refactoring to remove duplication in constructors --- .../xpack/inference/services/SenderService.java | 5 +---- .../amazonbedrock/AmazonBedrockService.java | 5 ++--- .../services/elastic/ElasticInferenceService.java | 15 +++++---------- .../services/sagemaker/SageMakerService.java | 14 +------------- 4 files changed, 9 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 7b7dbb7775623..4049bd115e226 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -51,10 +51,7 @@ public SenderService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - Objects.requireNonNull(factory); - sender = factory.createSender(); - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.clusterService = Objects.requireNonNull(context.clusterService()); + this(factory, serviceComponents, Objects.requireNonNull(context.clusterService())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index dcb6a715c1942..0236be89e9f5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -55,6 +55,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; @@ -98,11 +99,9 @@ public AmazonBedrockService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(httpSenderFactory, serviceComponents, context); - this.amazonBedrockSender = amazonBedrockFactory.createSender(); + this(httpSenderFactory, amazonBedrockFactory, serviceComponents, Objects.requireNonNull(context.clusterService())); } - // for testing public AmazonBedrockService( HttpRequestSender.Factory httpSenderFactory, AmazonBedrockRequestSender.Factory amazonBedrockFactory, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 3bac284f4f3cd..9ea53e5e3f91a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -72,6 +72,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; @@ -144,19 +145,13 @@ public ElasticInferenceService( ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); - this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( - elasticInferenceServiceSettings.getElasticInferenceServiceUrl() - ); - authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( + this( + factory, serviceComponents, + elasticInferenceServiceSettings, modelRegistry, authorizationRequestHandler, - initDefaultEndpoints(elasticInferenceServiceComponents), - IMPLEMENTED_TASK_TYPES, - this, - getSender(), - elasticInferenceServiceSettings + Objects.requireNonNull(context.clusterService()) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 27de36fa223d0..4455720258c3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -68,21 +68,9 @@ public SageMakerService( CheckedSupplier, RuntimeException> configurationMap, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - this.modelBuilder = modelBuilder; - this.client = client; - this.schemas = schemas; - this.threadPool = threadPool; - this.configuration = new LazyInitializable<>( - () -> new InferenceServiceConfiguration.Builder().setService(NAME) - .setName(DISPLAY_NAME) - .setTaskTypes(supportedTaskTypes()) - .setConfigurations(configurationMap.get()) - .build() - ); - this.clusterService = Objects.requireNonNull(context.clusterService()); + this(modelBuilder, client, schemas, threadPool, configurationMap, Objects.requireNonNull(context.clusterService())); } - // for testing public SageMakerService( SageMakerModelBuilder modelBuilder, SageMakerClient client, From 27a96ac777b347ff266f27649af411873f227082 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Wed, 16 Jul 2025 12:06:07 -0400 Subject: [PATCH 5/9] remove unnecessary blank line --- .../elasticsearch/xpack/inference/services/SenderService.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 4049bd115e226..20ea07a867409 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -52,7 +52,6 @@ public SenderService( InferenceServiceExtension.InferenceServiceFactoryContext context ) { this(factory, serviceComponents, Objects.requireNonNull(context.clusterService())); - } // for testing From c7666a3b4638e81ccda85d89be54a8a061b83981 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Thu, 17 Jul 2025 13:45:06 -0400 Subject: [PATCH 6/9] refactor to have uniform constructor call --- .../xpack/inference/services/SenderService.java | 11 +---------- .../alibabacloudsearch/AlibabaCloudSearchService.java | 3 +-- .../services/amazonbedrock/AmazonBedrockService.java | 2 +- .../services/anthropic/AnthropicService.java | 3 +-- .../services/azureaistudio/AzureAiStudioService.java | 3 +-- .../services/azureopenai/AzureOpenAiService.java | 3 +-- .../inference/services/cohere/CohereService.java | 3 +-- .../inference/services/custom/CustomService.java | 3 +-- .../inference/services/deepseek/DeepSeekService.java | 3 +-- .../services/elastic/ElasticInferenceService.java | 3 +-- .../googleaistudio/GoogleAiStudioService.java | 3 +-- .../googlevertexai/GoogleVertexAiService.java | 3 +-- .../services/huggingface/HuggingFaceBaseService.java | 3 +-- .../services/huggingface/HuggingFaceService.java | 3 +-- .../huggingface/elser/HuggingFaceElserService.java | 3 +-- .../services/ibmwatsonx/IbmWatsonxService.java | 3 +-- .../inference/services/jinaai/JinaAIService.java | 3 +-- .../inference/services/mistral/MistralService.java | 3 +-- .../inference/services/openai/OpenAiService.java | 3 +-- .../inference/services/voyageai/VoyageAIService.java | 3 +-- 20 files changed, 20 insertions(+), 47 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 20ea07a867409..1057952d14166 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -46,20 +46,11 @@ public abstract class SenderService implements InferenceService { private final ServiceComponents serviceComponents; private final ClusterService clusterService; - public SenderService( - HttpRequestSender.Factory factory, - ServiceComponents serviceComponents, - InferenceServiceExtension.InferenceServiceFactoryContext context - ) { - this(factory, serviceComponents, Objects.requireNonNull(context.clusterService())); - } - - // for testing public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.clusterService = clusterService; + this.clusterService = Objects.requireNonNull(clusterService); } public Sender getSender() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index ed08c979c0798..da608779fee0a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -92,10 +92,9 @@ public AlibabaCloudSearchService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public AlibabaCloudSearchService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 0236be89e9f5d..3652dd09cd694 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -99,7 +99,7 @@ public AmazonBedrockService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - this(httpSenderFactory, amazonBedrockFactory, serviceComponents, Objects.requireNonNull(context.clusterService())); + this(httpSenderFactory, amazonBedrockFactory, serviceComponents, context.clusterService()); } public AmazonBedrockService( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 5215916275bb3..8cf5446f8b6d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -65,10 +65,9 @@ public AnthropicService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 979a0408521bd..e41ae47cf8983 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -90,10 +90,9 @@ public AzureAiStudioService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index a0e1675c30938..3d9a3dd516a2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -76,10 +76,9 @@ public AzureOpenAiService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 4d9138304a5a0..fb6c630bd60c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -91,10 +91,9 @@ public CohereService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // For testing public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 294b3775b6af7..5f5078affa9d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -81,10 +81,9 @@ public CustomService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index c18ccbe7a2b13..8a77efbd604d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -65,10 +65,9 @@ public DeepSeekService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 9ea53e5e3f91a..014847392274a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -151,11 +151,10 @@ public ElasticInferenceService( elasticInferenceServiceSettings, modelRegistry, authorizationRequestHandler, - Objects.requireNonNull(context.clusterService()) + context.clusterService() ); } - // for testing public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index ad0ea5d878b80..4c8997f35555b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -89,10 +89,9 @@ public GoogleAiStudioService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 506c65f3191f0..2c2c667cd6eee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -104,10 +104,9 @@ public GoogleVertexAiService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 8d01fcab861fb..325f88c8904a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -51,10 +51,9 @@ public HuggingFaceBaseService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 63d97f9b1ef52..bc64e832d182a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -78,10 +78,9 @@ public HuggingFaceService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index ed08a53e0aace..5f9288bb99c24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -64,10 +64,9 @@ public HuggingFaceElserService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // For testing public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 3527d14b9b9a6..9617bff0d3f3d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -90,10 +90,9 @@ public IbmWatsonxService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index e1025fa874c45..00e1aede95a2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -83,10 +83,9 @@ public JinaAIService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 790afb766b024..3048847ea90d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -91,10 +91,9 @@ public MistralService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 9ee494168bf7d..b9e9e34c44736 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -98,10 +98,9 @@ public OpenAiService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index f26974e525b43..9698ee4c0d4bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -103,10 +103,9 @@ public VoyageAIService( ServiceComponents serviceComponents, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents, context); + this(factory, serviceComponents, context.clusterService()); } - // for testing public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } From bfec42b3559c6aa62997d078953c5cee069deef4 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Thu, 17 Jul 2025 13:47:06 -0400 Subject: [PATCH 7/9] refactor to have uniform constructor call for sagemaker --- .../xpack/inference/services/sagemaker/SageMakerService.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 4455720258c3e..653c4288263f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -68,7 +68,7 @@ public SageMakerService( CheckedSupplier, RuntimeException> configurationMap, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - this(modelBuilder, client, schemas, threadPool, configurationMap, Objects.requireNonNull(context.clusterService())); + this(modelBuilder, client, schemas, threadPool, configurationMap, context.clusterService()); } public SageMakerService( @@ -90,7 +90,7 @@ public SageMakerService( .setConfigurations(configurationMap.get()) .build() ); - this.clusterService = clusterService; + this.clusterService = Objects.requireNonNull(clusterService); } @Override From 5422e22b5182946df80106d1fb4c09dd8b3dcb91 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Thu, 17 Jul 2025 13:47:55 -0400 Subject: [PATCH 8/9] fix linting issues --- .../elasticsearch/xpack/inference/services/SenderService.java | 1 - .../inference/services/amazonbedrock/AmazonBedrockService.java | 1 - .../inference/services/elastic/ElasticInferenceService.java | 1 - 3 files changed, 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 1057952d14166..5074749c1cd9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -18,7 +18,6 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 3652dd09cd694..c2b0ae8e69c37 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -55,7 +55,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 014847392274a..58e964bb5c25f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -72,7 +72,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import java.util.Set; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; From 227507ef4767db17c36255066de2b997888df3b3 Mon Sep 17 00:00:00 2001 From: Samiul Monir Date: Thu, 17 Jul 2025 14:46:08 -0400 Subject: [PATCH 9/9] fix failed unit tests --- .../services/azureaistudio/AzureAiStudioServiceTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 227e02da3e65b..3383762a9f332 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1416,7 +1416,7 @@ public void testInfer_WithChatCompletionModel() throws IOException { public void testInfer_WithRerankModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson)); var model = AzureAiStudioRerankModelTests.createModel(