diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildConfig.java new file mode 100644 index 000000000..47a13478e --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildConfig.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.deployment; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; + +@ConfigRoot(phase = BUILD_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.cache") +public interface AiCacheBuildConfig { + + /** + * Ai Cache embedding model related settings + */ + CacheEmbeddingModelConfig embedding(); +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildItem.java new file mode 100644 index 000000000..294cdb028 --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildItem.java @@ -0,0 +1,22 @@ +package io.quarkiverse.langchain4j.deployment; + +import io.quarkus.builder.item.SimpleBuildItem; + +public final class AiCacheBuildItem extends SimpleBuildItem { + + private boolean enable; + private String embeddingModelName; + + public AiCacheBuildItem(boolean enable, String embeddingModelName) { + this.enable = enable; + this.embeddingModelName = embeddingModelName; + } + + public boolean isEnable() { + return enable; + } + + public String getEmbeddingModelName() { + return embeddingModelName; + } +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheProcessor.java new file mode 100644 index 000000000..b01ad8b51 --- /dev/null +++ 
b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheProcessor.java @@ -0,0 +1,81 @@ +package io.quarkiverse.langchain4j.deployment; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationTarget; +import org.jboss.jandex.ClassInfo; +import org.jboss.jandex.ClassType; +import org.jboss.jandex.IndexView; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.runtime.AiCacheRecorder; +import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkiverse.langchain4j.runtime.cache.config.AiCacheConfig; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.arc.deployment.UnremovableBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.CombinedIndexBuildItem; + +public class AiCacheProcessor { + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + void setupBeans(AiCacheBuildConfig cacheBuildConfig, + AiCacheConfig cacheConfig, + AiCacheRecorder recorder, + CombinedIndexBuildItem indexBuildItem, + BuildProducer aiCacheBuildItemProducer, + BuildProducer unremovableProducer, + BuildProducer syntheticBeanProducer) { + + IndexView index = indexBuildItem.getIndex(); + boolean enableCache = false; + + for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) { + if (instance.target().kind() != AnnotationTarget.Kind.CLASS) { + continue; + } + + ClassInfo declarativeAiServiceClassInfo = instance.target().asClass(); + + if 
(declarativeAiServiceClassInfo.hasAnnotation(LangChain4jDotNames.CACHE_RESULT)) { + enableCache = true; + break; + } + } + + String embeddingModelName = NamedConfigUtil.DEFAULT_NAME; + if (cacheBuildConfig.embedding() != null) + embeddingModelName = cacheBuildConfig.embedding().name().orElse(NamedConfigUtil.DEFAULT_NAME); + + aiCacheBuildItemProducer.produce(new AiCacheBuildItem(enableCache, embeddingModelName)); + + if (enableCache) { + SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem + .configure(AiCacheProvider.class) + .setRuntimeInit() + .addInjectionPoint(ClassType.create(AiCacheStore.class)) + .scope(ApplicationScoped.class) + .createWith(recorder.messageWindow(cacheConfig, embeddingModelName)) + .defaultBean(); + + if (NamedConfigUtil.isDefault(embeddingModelName)) { + configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL)); + } else { + configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL), + AnnotationInstance.builder(ModelName.class).add("value", embeddingModelName).build()); + } + + syntheticBeanProducer.produce(configurator.done()); + unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(AiCacheStore.class)); + unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(EmbeddingModel.class)); + } + } +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index be1e10b61..01cb0064b 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -52,7 +52,6 @@ import org.objectweb.asm.tree.analysis.AnalyzerException; import dev.langchain4j.exception.IllegalConfigurationException; -import dev.langchain4j.service.Moderate; import 
io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.ToolBox; import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig; @@ -189,6 +188,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, Set chatModelNames = new HashSet<>(); Set moderationModelNames = new HashSet<>(); + for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) { if (instance.target().kind() != AnnotationTarget.Kind.CLASS) { continue; // should never happen @@ -210,14 +210,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } String chatModelName = NamedConfigUtil.DEFAULT_NAME; + String moderationModelName = NamedConfigUtil.DEFAULT_NAME; + if (chatLanguageModelSupplierClassDotName == null) { AnnotationValue modelNameValue = instance.value("modelName"); - if (modelNameValue != null) { - String modelNameValueStr = modelNameValue.asString(); - if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) { - chatModelName = modelNameValueStr; - } - } + chatModelName = getModelName(modelNameValue); chatModelNames.add(chatModelName); } @@ -243,6 +240,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } } + // determine the AiCacheProvider supplier to use; a custom supplier class is validated and registered for reflection + DotName aiCacheProviderSupplierClassDotName = LangChain4jDotNames.BEAN_AI_CACHE_PROVIDER_SUPPLIER; + AnnotationValue aiCacheProviderSupplierValue = instance.value("cacheProviderSupplier"); + if (aiCacheProviderSupplierValue != null) { + aiCacheProviderSupplierClassDotName = aiCacheProviderSupplierValue.asClass().name(); + if (!aiCacheProviderSupplierClassDotName + .equals(LangChain4jDotNames.BEAN_AI_CACHE_PROVIDER_SUPPLIER)) { + validateSupplierAndRegisterForReflection(aiCacheProviderSupplierClassDotName, index, + reflectiveClassProducer); + } + } + DotName retrieverClassDotName = null; AnnotationValue retrieverValue =
instance.value("retriever"); if (retrieverValue != null) { @@ -296,17 +305,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } // determine whether the method is annotated with @Moderate - String moderationModelName = NamedConfigUtil.DEFAULT_NAME; for (MethodInfo method : declarativeAiServiceClassInfo.methods()) { if (method.hasAnnotation(LangChain4jDotNames.MODERATE)) { if (moderationModelSupplierClassName.equals(LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER)) { AnnotationValue modelNameValue = instance.value("modelName"); - if (modelNameValue != null) { - String modelNameValueStr = modelNameValue.asString(); - if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) { - moderationModelName = modelNameValueStr; - } - } + moderationModelName = getModelName(modelNameValue); moderationModelNames.add(moderationModelName); } break; @@ -325,13 +328,15 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, chatLanguageModelSupplierClassDotName, toolDotNames, chatMemoryProviderSupplierClassDotName, + aiCacheProviderSupplierClassDotName, retrieverClassDotName, retrievalAugmentorSupplierClassName, customRetrievalAugmentorSupplierClassIsABean, auditServiceSupplierClassName, moderationModelSupplierClassName, cdiScope, - chatModelName, moderationModelName)); + chatModelName, + moderationModelName)); } for (String chatModelName : chatModelNames) { @@ -365,7 +370,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, List declarativeAiServiceItems, List selectedChatModelProvider, BuildProducer syntheticBeanProducer, - BuildProducer unremoveableProducer) { + BuildProducer unremoveableProducer, + AiCacheBuildItem aiCacheBuildItem) { boolean needsChatModelBean = false; boolean needsStreamingChatModelBean = false; @@ -374,6 +380,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, boolean needsRetrievalAugmentorBean = false; boolean needsAuditServiceBean = false; boolean 
needsModerationModelBean = false; + boolean needsAiCacheProvider = false; + Set allToolNames = new HashSet<>(); for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) { @@ -390,6 +398,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, ? bi.getChatMemoryProviderSupplierClassDotName().toString() : null; + String aiCacheProviderSupplierClassName = bi.getAiCacheProviderSupplierClassDotName() != null + ? bi.getAiCacheProviderSupplierClassDotName().toString() + : null; + String retrieverClassName = bi.getRetrieverClassDotName() != null ? bi.getRetrieverClassDotName().toString() : null; @@ -407,7 +419,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, : null); // determine whether the method returns Multi - boolean injectStreamingChatModelBean = false; + boolean needsStreamingChatModel = false; for (MethodInfo method : declarativeAiServiceClassInfo.methods()) { if (!LangChain4jDotNames.MULTI.equals(method.returnType().name())) { continue; @@ -423,29 +435,41 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, throw illegalConfiguration("Only Multi is supported as a Multi return type. Offending method is '" + method.declaringClass().name().toString() + "#" + method.name() + "'"); } - injectStreamingChatModelBean = true; + needsStreamingChatModel = true; } - boolean injectModerationModelBean = false; + boolean needsModerationModel = false; for (MethodInfo method : declarativeAiServiceClassInfo.methods()) { - if (method.hasAnnotation(Moderate.class)) { - injectModerationModelBean = true; + if (method.hasAnnotation(LangChain4jDotNames.MODERATE)) { + needsModerationModel = true; break; } } String chatModelName = bi.getChatModelName(); String moderationModelName = bi.getModerationModelName(); + boolean enableCache = aiCacheBuildItem.isEnable(); + + // It is not possible to use the cache in combination with the tools. 
+ if (!toolClassNames.isEmpty() && enableCache + && declarativeAiServiceClassInfo.hasAnnotation(LangChain4jDotNames.CACHE_RESULT)) { + throw new RuntimeException("The cache cannot be used in combination with the tools. Affected class: %s" + .formatted(serviceClassName)); + } + SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem .configure(QuarkusAiServiceContext.class) .forceApplicationClass() .createWith(recorder.createDeclarativeAiService( new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName, - toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName, + toolClassNames, chatMemoryProviderSupplierClassName, aiCacheProviderSupplierClassName, + retrieverClassName, retrievalAugmentorSupplierClassName, auditServiceClassSupplierName, moderationModelSupplierClassName, chatModelName, moderationModelName, - injectStreamingChatModelBean, injectModerationModelBean))) + needsStreamingChatModel, + needsModerationModel, + enableCache))) .setRuntimeInit() .addQualifier() .annotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName) @@ -455,7 +479,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if ((chatLanguageModelSupplierClassName == null) && !selectedChatModelProvider.isEmpty()) { if (NamedConfigUtil.isDefault(chatModelName)) { configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL)); - if (injectStreamingChatModelBean) { + if (needsStreamingChatModel) { configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.STREAMING_CHAT_MODEL)); needsStreamingChatModelBean = true; } @@ -463,7 +487,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL), AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build()); - if (injectStreamingChatModelBean) { + if (needsStreamingChatModel) { 
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.STREAMING_CHAT_MODEL), AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build()); needsStreamingChatModelBean = true; @@ -519,7 +543,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, } if (LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER.toString() - .equals(moderationModelSupplierClassName) && injectModerationModelBean) { + .equals(moderationModelSupplierClassName) && needsModerationModel) { if (NamedConfigUtil.isDefault(moderationModelName)) { configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.MODERATION_MODEL)); @@ -531,6 +555,16 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, needsModerationModelBean = true; } + if (enableCache) { + + if (LangChain4jDotNames.BEAN_AI_CACHE_PROVIDER_SUPPLIER.toString().equals(aiCacheProviderSupplierClassName)) { + configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.AI_CACHE_PROVIDER)); + } else { + configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.AI_CACHE_PROVIDER)); + } + needsAiCacheProvider = true; + } + syntheticBeanProducer.produce(configurator.done()); } @@ -555,6 +589,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if (needsModerationModelBean) { unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.MODERATION_MODEL)); } + if (needsAiCacheProvider) { + unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.AI_CACHE_PROVIDER)); + } if (!allToolNames.isEmpty()) { unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames)); } @@ -877,6 +914,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE); Class returnType = JandexUtil.load(method.returnType(), Thread.currentThread().getContextClassLoader()); + boolean requiresCache 
= method.declaringClass().hasDeclaredAnnotation(LangChain4jDotNames.CACHE_RESULT) + || method.hasDeclaredAnnotation(LangChain4jDotNames.CACHE_RESULT); List params = method.parameters(); @@ -914,7 +953,7 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea List methodToolClassNames = gatherMethodToolClassNames(method); return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo, - userMessageInfo, memoryIdParamPosition, requiresModeration, + userMessageInfo, memoryIdParamPosition, requiresModeration, requiresCache, returnType, metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames); } @@ -1249,6 +1288,16 @@ static Map toNameToArgsPositionMap(List } } + private String getModelName(AnnotationValue value) { + if (value != null) { + String modelNameValueStr = value.asString(); + if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) { + return modelNameValueStr; + } + } + return NamedConfigUtil.DEFAULT_NAME; + } + public static final class AiServicesMethodBuildItem extends MultiBuildItem { private final MethodInfo methodInfo; diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java index e7763d46f..5b3a28597 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java @@ -66,7 +66,9 @@ void indexDependencies(BuildProducer producer) { } @BuildStep - public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + public void handleProviders( + AiCacheBuildItem aiCacheBuildItem, + BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, List chatCandidateItems, List embeddingCandidateItems, List moderationCandidateItems, @@ -170,8 +172,35 @@ public void 
handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished selectedEmbeddingProducer.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, modelName)); } } + + if (aiCacheBuildItem.isEnable() && !requestEmbeddingModels.contains(aiCacheBuildItem.getEmbeddingModelName())) { + + String modelName = aiCacheBuildItem.getEmbeddingModelName(); + String configNamespace; + Optional userSelectedProvider; + + if (NamedConfigUtil.isDefault(modelName)) { + userSelectedProvider = buildConfig.defaultConfig().embeddingModel().provider(); + configNamespace = "embedding-model"; + } else { + if (buildConfig.namedConfig().containsKey(modelName)) { + userSelectedProvider = buildConfig.namedConfig().get(modelName).embeddingModel().provider(); + } else { + userSelectedProvider = Optional.empty(); + } + configNamespace = modelName + ".embedding-model"; + } + + String provider = selectEmbeddingModelProvider(inProcessEmbeddingBuildItems, embeddingCandidateItems, + beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class), + userSelectedProvider, "EmbeddingModel", configNamespace); + selectedEmbeddingProducer + .produce(new SelectedEmbeddingModelCandidateBuildItem(provider, modelName)); + } + // If the Easy RAG extension requested to automatically generate an embedding model... 
- if (requestEmbeddingModels.isEmpty() && autoCreateEmbeddingModelBuildItem.isPresent()) { + if (requestEmbeddingModels.isEmpty() + && autoCreateEmbeddingModelBuildItem.isPresent()) { String provider = selectEmbeddingModelProvider(inProcessEmbeddingBuildItems, embeddingCandidateItems, beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class), Optional.empty(), "EmbeddingModel", "embedding-model"); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/CacheEmbeddingModelConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/CacheEmbeddingModelConfig.java new file mode 100644 index 000000000..3d9d83245 --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/CacheEmbeddingModelConfig.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.deployment; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface CacheEmbeddingModelConfig { + + /** + * Name of the embedding model to use in the semantic cache. 
+ */ + Optional name(); +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index f37e67ee5..44e2a0582 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -17,6 +17,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final List toolDotNames; private final DotName chatMemoryProviderSupplierClassDotName; + private final DotName aiCacheProviderSupplierClassDotName; private final DotName retrieverClassDotName; private final DotName retrievalAugmentorSupplierClassDotName; private final boolean customRetrievalAugmentorSupplierClassIsABean; @@ -29,6 +30,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName, List toolDotNames, DotName chatMemoryProviderSupplierClassDotName, + DotName aiCacheProviderSupplierClassDotName, DotName retrieverClassDotName, DotName retrievalAugmentorSupplierClassDotName, boolean customRetrievalAugmentorSupplierClassIsABean, @@ -41,6 +43,7 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag this.languageModelSupplierClassDotName = languageModelSupplierClassDotName; this.toolDotNames = toolDotNames; this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName; + this.aiCacheProviderSupplierClassDotName = aiCacheProviderSupplierClassDotName; this.retrieverClassDotName = retrieverClassDotName; this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName; this.customRetrievalAugmentorSupplierClassIsABean = customRetrievalAugmentorSupplierClassIsABean; @@ -67,6 +70,10 
@@ public DotName getChatMemoryProviderSupplierClassDotName() { return chatMemoryProviderSupplierClassDotName; } + public DotName getAiCacheProviderSupplierClassDotName() { + return aiCacheProviderSupplierClassDotName; + } + public DotName getRetrieverClassDotName() { return retrieverClassDotName; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java index 264503f68..198ff1210 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java @@ -22,11 +22,13 @@ import dev.langchain4j.service.UserName; import dev.langchain4j.web.search.WebSearchEngine; import dev.langchain4j.web.search.WebSearchTool; +import io.quarkiverse.langchain4j.CacheResult; import io.quarkiverse.langchain4j.CreatedAware; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditService; import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; import io.smallrye.mutiny.Multi; public class LangChain4jDotNames { @@ -42,6 +44,7 @@ public class LangChain4jDotNames { static final DotName USER_MESSAGE = DotName.createSimple(UserMessage.class); static final DotName USER_NAME = DotName.createSimple(UserName.class); static final DotName MODERATE = DotName.createSimple(Moderate.class); + static final DotName CACHE_RESULT = DotName.createSimple(CacheResult.class); static final DotName MEMORY_ID = DotName.createSimple(MemoryId.class); static final DotName DESCRIPTION = DotName.createSimple(Description.class); static final DotName STRUCTURED_PROMPT = DotName.createSimple(StructuredPrompt.class); @@ -55,10 +58,14 @@ public class LangChain4jDotNames 
{ RegisterAiService.BeanChatLanguageModelSupplier.class); static final DotName CHAT_MEMORY_PROVIDER = DotName.createSimple(ChatMemoryProvider.class); + static final DotName AI_CACHE_PROVIDER = DotName.createSimple(AiCacheProvider.class); static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple( RegisterAiService.BeanChatMemoryProviderSupplier.class); + static final DotName BEAN_AI_CACHE_PROVIDER_SUPPLIER = DotName.createSimple( + RegisterAiService.BeanAiCacheProviderSupplier.class); + static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple( RegisterAiService.NoChatMemoryProviderSupplier.class); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AutoCreateEmbeddingModelBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AutoCreateEmbeddingModelBuildItem.java index 5dfc0faa5..b1aba4f63 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AutoCreateEmbeddingModelBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AutoCreateEmbeddingModelBuildItem.java @@ -4,8 +4,8 @@ /** * Request to generate an embedding model even if there are no - * non-synthetic injection points for it. This is used by the Easy RAG - * extension to have an embedding model created automatically. + * non-synthetic injection points for it. This is used by the Easy RAG and CacheResult + * to have an embedding model created automatically. 
*/ public final class AutoCreateEmbeddingModelBuildItem extends SimpleBuildItem { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/CacheResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/CacheResult.java new file mode 100644 index 000000000..cfc8fd3ba --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/CacheResult.java @@ -0,0 +1,11 @@ +package io.quarkiverse.langchain4j; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(value = { ElementType.TYPE, ElementType.METHOD }) +public @interface CacheResult { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 6e9df6744..164863c0c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -20,45 +20,51 @@ import dev.langchain4j.store.memory.chat.ChatMemoryStore; import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; import io.quarkiverse.langchain4j.audit.AuditService; +import io.quarkiverse.langchain4j.runtime.cache.AiCache; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache; +import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore; /** - * Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by - * using the class as a CDI bean. 
- * Under the hood LangChain4j's {@link AiServices#builder(Class)} is called - * while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional), - * {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist). + * Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by using the + * class as a CDI bean. Under the hood LangChain4j's {@link AiServices#builder(Class)} is called while also providing the + * builder + * with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional), {@link ChatMemoryProvider} and + * {@link Retriever} beans (which by default are configured if such beans exist). *

* NOTE: The resulting CDI bean is {@link jakarta.enterprise.context.RequestScoped} by default. If you need to change the scope, - * simply annotate the class with a CDI scope. - * CAUTION: When using anything other than the request scope, you need to be very careful with the chat memory implementation. + * simply annotate the class with a CDI scope. CAUTION: When using anything other than the request scope, you need to be very + * careful with the chat memory implementation. *

- * NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated - * for the method invocations. + * NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated for + * the + * method invocations. */ @Retention(RUNTIME) @Target(ElementType.TYPE) public @interface RegisterAiService { /** - * Configures the way to obtain the {@link ChatLanguageModel} to use. - * If not configured, the default CDI bean implementing the model is looked up. - * Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai}, - * {@code quarkus-langchain4j-azure-openai} or - * {@code quarkus-langchain4j-hugging-face} + * Configures the way to obtain the {@link ChatLanguageModel} to use. If not configured, the default CDI bean implementing + * the + * model is looked up. Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai}, + * {@code quarkus-langchain4j-azure-openai} or {@code quarkus-langchain4j-hugging-face} */ Class> chatLanguageModelSupplier() default BeanChatLanguageModelSupplier.class; /** - * When {@code chatLanguageModelSupplier} is set to {@code BeanChatLanguageModelSupplier.class} (which is the default) - * this allows the selection of the {@link ChatLanguageModel} CDI bean to use. + * When {@code chatLanguageModelSupplier} is set to {@code BeanChatLanguageModelSupplier.class} (which is the default) this + * allows + * the selection of the {@link ChatLanguageModel} CDI bean to use. *

- * If not set, the default model (i.e. the one configured without setting the model name) is used. - * An example of the default model configuration is the following: - * {@code quarkus.langchain4j.openai.chat-model.model-name=gpt-4-turbo-preview} + * If not set, the default model (i.e. the one configured without setting the model name) is used. An example of the default + * model + * configuration is the following: {@code quarkus.langchain4j.openai.chat-model.model-name=gpt-4-turbo-preview} * - * If set, it uses the model configured by name. For example if this is set to {@code somename} - * an example configuration value for that named model could be: - * {@code quarkus.langchain4j.somename.openai.chat-model.model-name=gpt-4-turbo-preview} + * If set, it uses the model configured by name. For example if this is set to {@code somename} an example configuration + * value for + * that named model could be: {@code quarkus.langchain4j.somename.openai.chat-model.model-name=gpt-4-turbo-preview} */ String modelName() default ""; @@ -70,27 +76,49 @@ /** * Configures the way to obtain the {@link ChatMemoryProvider}. *

- * Be default, Quarkus configures a {@link ChatMemoryProvider} bean that uses a {@link InMemoryChatMemoryStore} bean - as the backing store. The default type for the actual {@link ChatMemory} is {@link MessageWindowChatMemory} - and it is configured with the value of the {@code quarkus.langchain4j.chat-memory.memory-window.max-messages} - configuration property (which default to 10) as a way of limiting the number of messages in each chat. + * By default, Quarkus configures a {@link ChatMemoryProvider} bean that uses a {@link InMemoryChatMemoryStore} bean as the + backing store. The default type for the actual {@link ChatMemory} is {@link MessageWindowChatMemory} and it is configured + with + the value of the {@code quarkus.langchain4j.chat-memory.memory-window.max-messages} configuration property (which defaults + to + 10) as a way of limiting the number of messages in each chat. *

* If the application provides its own {@link ChatMemoryProvider} bean, that takes precedence over what Quarkus provides as - * the default. + * the + * default. *

* If the application provides an implementation of {@link ChatMemoryStore}, then that is used instead of the default * {@link InMemoryChatMemoryStore}. *

* In the most advances case, an arbitrary {@link ChatMemoryProvider} can be used by having a custom - * {@code Supplier} configured in this property. - * {@link Supplier} needs to be provided. + * {@code Supplier} configured in this property. {@link Supplier} needs to be + * provided. *

*/ Class> chatMemoryProviderSupplier() default BeanChatMemoryProviderSupplier.class; /** - * Configures the way to obtain the {@link Retriever} to use (when using RAG). - * By default, no retriever is used. + * Configures the way to obtain the {@link AiCacheProvider}. + *

+ * By default, Quarkus configures a {@link AiCacheProvider} bean that uses an {@link InMemoryAiCacheStore} bean as the + * backing store. The default type for the actual {@link AiCache} is {@link FixedAiCache} and it is configured with + * the value of the {@code quarkus.langchain4j.cache.max-size} configuration property (which defaults to + * 1) as a way of limiting the number of messages in each cache. + *

+ * If the application provides its own {@link AiCacheProvider} bean, that takes precedence over what Quarkus provides as the + * default. + *

+ * If the application provides an implementation of {@link AiCacheStore}, then that is used instead of the default + * {@link InMemoryAiCacheStore}. + *

+ * In the most advanced case, an arbitrary {@link AiCacheProvider} can be used by having a custom + * {@code Supplier} configured in this property. {@link Supplier} needs to be provided. + *

+ */ + Class> cacheProviderSupplier() default BeanAiCacheProviderSupplier.class; + + /** + * Configures the way to obtain the {@link Retriever} to use (when using RAG). By default, no retriever is used. * * @deprecated Use retrievalAugmentor instead */ @@ -98,38 +126,39 @@ Class> retriever() default NoRetriever.class; /** - * Configures the way to obtain the {@link RetrievalAugmentor} to use - * (when using RAG). The Supplier may or may not be a CDI bean (but most - * typically it will, so consider adding a bean-defining annotation to - * it). If it is not a CDI bean, Quarkus will create an instance - * by calling its no-arg constructor. + * Configures the way to obtain the {@link RetrievalAugmentor} to use (when using RAG). The Supplier may or may not be a CDI + * bean + * (but most typically it will, so consider adding a bean-defining annotation to it). If it is not a CDI bean, Quarkus will + * create + * an instance by calling its no-arg constructor. * - * If unspecified, Quarkus will attempt to locate a CDI bean that - * implements {@link RetrievalAugmentor} and use it if one exists. + * If unspecified, Quarkus will attempt to locate a CDI bean that implements {@link RetrievalAugmentor} and use it if one + * exists. */ Class> retrievalAugmentor() default BeanIfExistsRetrievalAugmentorSupplier.class; /** - * Configures the way to obtain the {@link AuditService} to use. - * By default, Quarkus will look for a CDI bean that implements {@link AuditService}, but will fall back to not using - * any memory if no such bean exists. - * If an arbitrary {@link AuditService} instance is needed, a custom implementation of - * {@link Supplier} needs to be provided. + * Configures the way to obtain the {@link AuditService} to use. By default, Quarkus will look for a CDI bean that + * implements + * {@link AuditService}, but will fall back to not using any memory if no such bean exists. 
If an arbitrary + * {@link AuditService} + * instance is needed, a custom implementation of {@link Supplier} needs to be provided. */ Class> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class; /** - * Configures the way to obtain the {@link ModerationModel} to use. - * By default, Quarkus will look for a CDI bean that implements {@link ModerationModel} if at least one method is annotated - * with @Moderate. - * If an arbitrary {@link ModerationModel} instance is needed, a custom implementation of {@link Supplier} - * needs to be provided. + * Configures the way to obtain the {@link ModerationModel} to use. By default, Quarkus will look for a CDI bean that + * implements + * {@link ModerationModel} if at least one method is annotated with @Moderate. If an arbitrary {@link ModerationModel} + * instance is + * needed, a custom implementation of {@link Supplier} needs to be provided. */ Class> moderationModelSupplier() default BeanIfExistsModerationModelSupplier.class; /** - * Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by - * any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and + * Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by any of + * the + * extensions providing such capability (such as {@code quarkus-langchain4j-openai} and * {@code quarkus-langchain4j-hugging-face}). */ final class BeanChatLanguageModelSupplier implements Supplier { @@ -141,11 +170,12 @@ public ChatLanguageModel get() { } /** - * Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. 
- * Be default, Quarkus configures an {@link ChatMemoryProvider} by using an {@link InMemoryChatMemoryStore} - * as the backing store while using {@link MessageWindowChatMemory} with the value of - * configuration property {@code quarkus.langchain4j.chat-memory.memory-window.max-messages} (which default to 10) - * as a way of limiting the number of messages in each chat. + * Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. Be default, Quarkus + * configures an {@link ChatMemoryProvider} by using an {@link InMemoryChatMemoryStore} as the backing store while using + * {@link MessageWindowChatMemory} with the value of configuration property + * {@code quarkus.langchain4j.chat-memory.memory-window.max-messages} (which default to 10) as a way of limiting the number + * of + * messages in each chat. */ final class BeanChatMemoryProviderSupplier implements Supplier { @@ -155,6 +185,21 @@ public ChatMemoryProvider get() { } } + /** + * Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. Be default, Quarkus + * configures an {@link AiCacheProvider} by using an {@link InMemoryAiCacheStore} as the backing store while using + * {@link MessageWindowAiCacheMemory} with the value of configuration property + * {@code quarkus.langchain4j.cache.max-size} (which default to 1) as a way of limiting the number of + * messages in each cache. + */ + final class BeanAiCacheProviderSupplier implements Supplier { + + @Override + public AiCacheProvider get() { + throw new UnsupportedOperationException("should never be called"); + } + } + /** * Marker that is used when the user does not want any memory configured for the AiService */ @@ -178,8 +223,9 @@ public List findRelevant(String text) { } /** - * Marker that is used to tell Quarkus to use the {@link RetrievalAugmentor} that the user has configured as a CDI bean. - * If no such bean exists, then no retrieval augmentor will be used. 
+ * Marker that is used to tell Quarkus to use the {@link RetrievalAugmentor} that the user has configured as a CDI bean. If + * no + * such bean exists, then no retrieval augmentor will be used. */ final class BeanIfExistsRetrievalAugmentorSupplier implements Supplier { @@ -190,8 +236,9 @@ public RetrievalAugmentor get() { } /** - * Marker that is used to tell Quarkus to not use any retrieval augmentor even if a CDI bean implementing - * the `RetrievalAugmentor` interface exists. + * Marker that is used to tell Quarkus to not use any retrieval augmentor even if a CDI bean implementing the + * `RetrievalAugmentor` + * interface exists. */ final class NoRetrievalAugmentorSupplier implements Supplier { @@ -202,8 +249,9 @@ public RetrievalAugmentor get() { } /** - * Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean. - * If no such bean exists, then no audit service will be used. + * Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean. If no + * such bean + * exists, then no audit service will be used. */ final class BeanIfExistsAuditServiceSupplier implements Supplier { @@ -214,8 +262,9 @@ public AuditService get() { } /** - * Marker that is used to tell Quarkus to use the {@link ModerationModel} that the user has configured as a CDI bean. - * If no such bean exists, then no audit service will be used. + * Marker that is used to tell Quarkus to use the {@link ModerationModel} that the user has configured as a CDI bean. If no + * such + * bean exists, then no moderation model will be used. 
*/ final class BeanIfExistsModerationModelSupplier implements Supplier { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java new file mode 100644 index 000000000..5f856a6e1 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java @@ -0,0 +1,59 @@ +package io.quarkiverse.langchain4j.runtime; + +import java.time.Duration; +import java.util.function.Function; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.runtime.cache.AiCache; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache; +import io.quarkiverse.langchain4j.runtime.cache.config.AiCacheConfig; +import io.quarkus.arc.SyntheticCreationalContext; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class AiCacheRecorder { + + public Function, AiCacheProvider> messageWindow(AiCacheConfig config, + String embeddingModelName) { + return new Function<>() { + @Override + public AiCacheProvider apply(SyntheticCreationalContext context) { + + EmbeddingModel embeddingModel; + AiCacheStore aiCacheStore = context.getInjectedReference(AiCacheStore.class); + + if (NamedConfigUtil.isDefault(embeddingModelName)) { + embeddingModel = context.getInjectedReference(EmbeddingModel.class); + } else { + embeddingModel = context.getInjectedReference(EmbeddingModel.class, + ModelName.Literal.of(embeddingModelName)); + } + + double threshold = config.threshold(); + int maxSize = config.maxSize(); + Duration ttl = config.ttl().orElse(null); + String queryPrefix = config.embedding().queryPrefix().orElse(""); + String passagePrefix = config.embedding().passagePrefix().orElse(""); + + return new AiCacheProvider() { + @Override 
+ public AiCache get(Object memoryId) { + return FixedAiCache.Builder + .create(memoryId) + .ttl(ttl) + .maxSize(maxSize) + .threshold(threshold) + .queryPrefix(queryPrefix) + .passagePrefix(passagePrefix) + .embeddingModel(embeddingModel) + .store(aiCacheStore) + .build(); + } + }; + } + }; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index 000790b6b..740a97325 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -7,6 +7,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.function.Supplier; @@ -27,6 +28,7 @@ import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; import io.quarkus.arc.Arc; import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.annotations.Recorder; @@ -234,6 +236,26 @@ public T apply(SyntheticCreationalContext creationalContext) { } } + if (info.enableCache()) { + + if (info.aiCacheProviderSupplierClassName() != null) { + if (RegisterAiService.BeanAiCacheProviderSupplier.class.getName() + .equals(info.aiCacheProviderSupplierClassName())) { + aiServiceContext.aiCacheProvider = creationalContext.getInjectedReference( + AiCacheProvider.class); + } else { + Supplier supplier = (Supplier) Thread + .currentThread().getContextClassLoader() + .loadClass(info.aiCacheProviderSupplierClassName()) + .getConstructor().newInstance(); + aiServiceContext.aiCacheProvider = supplier.get(); + 
} + } + + if (aiServiceContext.aiCaches == null) + aiServiceContext.aiCaches = new ConcurrentHashMap<>(); + } + return (T) aiServiceContext; } catch (ClassNotFoundException e) { throw new IllegalStateException(e); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index 998594985..7c313e86d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -19,6 +19,7 @@ public final class AiServiceMethodCreateInfo { private final UserMessageInfo userMessageInfo; private final Optional memoryIdParamPosition; private final boolean requiresModeration; + private final boolean requiresCache; private final Class returnType; private final Optional metricsTimedInfo; private final Optional metricsCountedInfo; @@ -36,8 +37,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName, Optional systemMessageInfo, UserMessageInfo userMessageInfo, Optional memoryIdParamPosition, - boolean requiresModeration, - Class returnType, + boolean requiresModeration, boolean requiresCache, Class returnType, Optional metricsTimedInfo, Optional metricsCountedInfo, Optional spanInfo, @@ -49,6 +49,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName, this.userMessageInfo = userMessageInfo; this.memoryIdParamPosition = memoryIdParamPosition; this.requiresModeration = requiresModeration; + this.requiresCache = requiresCache; this.returnType = returnType; this.metricsTimedInfo = metricsTimedInfo; this.metricsCountedInfo = metricsCountedInfo; @@ -81,6 +82,10 @@ public boolean isRequiresModeration() { return requiresModeration; } + public boolean isRequiresCache() { + return requiresCache; + } + public Class getReturnType() { 
return returnType; } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 38da8dcd4..457521ae2 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -48,6 +48,7 @@ import io.quarkiverse.langchain4j.audit.Audit; import io.quarkiverse.langchain4j.audit.AuditService; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; +import io.quarkiverse.langchain4j.runtime.cache.AiCache; import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.infrastructure.Infrastructure; @@ -121,7 +122,12 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob } Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null); + AiCache cache = null; + if (methodCreateInfo.isRequiresCache()) { + Object cacheId = cacheId(methodCreateInfo); + cache = context.cache(cacheId); + } if (context.retrievalAugmentor != null) { // TODO extract method/class List chatMemory = context.hasChatMemory() ? context.chatMemory(memoryId).messages() @@ -171,9 +177,6 @@ public void accept(Response message) { } Future moderationFuture = triggerModerationIfNeeded(context, methodCreateInfo, messages); - - log.debug("Attempting to obtain AI response"); - List toolSpecifications = context.toolSpecifications; Map toolExecutors = context.toolExecutors; // override with method specific info @@ -182,16 +185,30 @@ public void accept(Response message) { toolExecutors = methodCreateInfo.getToolExecutors(); } - Response response = toolSpecifications == null - ? 
context.chatModel.generate(messages) - : context.chatModel.generate(messages, toolSpecifications); - log.debug("AI response obtained"); + Response response; + + if (cache != null) { + log.debug("Attempting to obtain AI response from cache"); + + var cacheResponse = cache.search(systemMessage.orElse(null), userMessage); + + if (cacheResponse.isPresent()) { + log.debug("Return cached response"); + response = Response.from(cacheResponse.get()); + } else { + response = executeLLMCall(context, messages, moderationFuture, toolSpecifications); + cache.add(systemMessage.orElse(null), userMessage, response.content()); + } + + } else { + response = executeLLMCall(context, messages, moderationFuture, toolSpecifications); + } + if (audit != null) { audit.addLLMToApplicationMessage(response); } - TokenUsage tokenUsageAccumulator = response.tokenUsage(); - verifyModerationIfNeeded(moderationFuture); + TokenUsage tokenUsageAccumulator = response.tokenUsage(); int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS; while (true) { @@ -245,6 +262,17 @@ public void accept(Response message) { return parse(response, returnType); } + private static Response executeLLMCall(QuarkusAiServiceContext context, List messages, + Future moderationFuture, List toolSpecifications) { + log.debug("Attempting to obtain AI response"); + var response = context.toolSpecifications == null + ? 
context.chatModel.generate(messages) + : context.chatModel.generate(messages, toolSpecifications); + log.debug("AI response obtained"); + verifyModerationIfNeeded(moderationFuture); + return response; + } + private static Future triggerModerationIfNeeded(AiServiceContext context, AiServiceMethodCreateInfo createInfo, List messages) { @@ -384,6 +412,17 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me return "default"; } + private static Object cacheId(AiServiceMethodCreateInfo createInfo) { + for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) { + Object memoryId = provider.getMemoryId(); + if (memoryId != null) { + String perServiceSuffix = "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName(); + return memoryId + perServiceSuffix; + } + } + return "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName(); + } + // TODO: share these methods with LangChain4j private static String toString(Object arg) { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index 0b8615079..4964277fc 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -6,6 +6,7 @@ public record DeclarativeAiServiceCreateInfo(String serviceClassName, String languageModelSupplierClassName, List toolsClassNames, String chatMemoryProviderSupplierClassName, + String aiCacheProviderSupplierClassName, String retrieverClassName, String retrievalAugmentorSupplierClassName, String auditServiceClassSupplierName, @@ -13,5 +14,6 @@ public record DeclarativeAiServiceCreateInfo(String serviceClassName, String chatModelName, String moderationModelName, boolean 
needsStreamingChatModel, - boolean needsModerationModel) { + boolean needsModerationModel, + boolean enableCache) { } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/InMemoryAiCacheStoreProducer.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/InMemoryAiCacheStoreProducer.java new file mode 100644 index 000000000..2f768373d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/InMemoryAiCacheStoreProducer.java @@ -0,0 +1,21 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import jakarta.enterprise.inject.Produces; +import jakarta.inject.Singleton; + +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore; +import io.quarkus.arc.DefaultBean; + +/** + * Creates the default {@link InMemoryAiCacheStoreProducer} store to be used by classes annotated with {@link RegisterAiService} + */ +public class InMemoryAiCacheStoreProducer { + + @Produces + @Singleton + @DefaultBean + public InMemoryAiCacheStore chatMemoryStore() { + return new InMemoryAiCacheStore(); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java index d9edf9f91..c608d9bb9 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java @@ -1,15 +1,20 @@ package io.quarkiverse.langchain4j.runtime.aiservice; +import java.util.Map; import java.util.function.BiConsumer; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.service.AiServiceContext; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditService; +import 
io.quarkiverse.langchain4j.runtime.cache.AiCache; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; public class QuarkusAiServiceContext extends AiServiceContext { public AuditService auditService; + public Map aiCaches; + public AiCacheProvider aiCacheProvider; // needed by Arc public QuarkusAiServiceContext() { @@ -20,12 +25,21 @@ public QuarkusAiServiceContext(Class aiServiceClass) { super(aiServiceClass); } + public boolean hasCache() { + return aiCaches != null; + } + + public AiCache cache(Object cacheId) { + return aiCaches.computeIfAbsent(cacheId, ignored -> aiCacheProvider.get(cacheId)); + } + /** * This is called by the {@code close} method of AiServices registered with {@link RegisterAiService} * when the bean's scope is closed */ public void close() { clearChatMemory(); + clearAiCache(); } private void clearChatMemory() { @@ -40,6 +54,18 @@ public void accept(Object memoryId, ChatMemory chatMemory) { } } + private void clearAiCache() { + if (aiCaches != null) { + aiCaches.forEach(new BiConsumer<>() { + @Override + public void accept(Object cacheId, AiCache aiCache) { + aiCache.clear(); + } + }); + aiCaches = null; + } + } + /** * This is called by the {@code remove(Object... ids)} method of AiServices when a user manually requests removal of chat * memories diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCache.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCache.java new file mode 100644 index 000000000..8b1b8bc76 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCache.java @@ -0,0 +1,43 @@ +package io.quarkiverse.langchain4j.runtime.cache; + +import java.util.Optional; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; + +/** + * Represents the cache of a AI. It can be used to reduces response time for similar queries. 
+ */ +public interface AiCache { + + /** + * The ID of the {@link AiCache}. + * + * @return The ID of the {@link AiCache}. + */ + Object id(); + + /** + * Cache a new message. + * + * @param systemMessage {@link SystemMessage} value to add to the cache. + * @param userMessage {@link UserMessage} value to add to the cache. + * @param aiResponse {@link AiMessage} value to add to the cache. + */ + void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage aiResponse); + + /** + * Check if there is a response in the cache that is semantically close to the cached items. + * + * @param systemMessage {@link SystemMessage} value to find in the cache. + * @param userMessage {@link UserMessage} value to find in the cache. + * @return + */ + Optional search(SystemMessage systemMessage, UserMessage userMessage); + + /** + * Clears the cache. + */ + void clear(); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCacheProvider.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCacheProvider.java new file mode 100644 index 000000000..d2e8279a1 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCacheProvider.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.runtime.cache; + +import io.quarkiverse.langchain4j.RegisterAiService; + +/** + * Provides instances of {@link AiCache}. + * Intended to be used with {@link RegisterAiService} + */ +@FunctionalInterface +public interface AiCacheProvider { + + /** + * Provides an instance of {@link AiCache}. + * + * @param id The ID of the cache. + * @return A {@link AiCache} instance. 
+ */ + AiCache get(Object id); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCacheStore.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCacheStore.java new file mode 100644 index 000000000..abaee7e36 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCacheStore.java @@ -0,0 +1,43 @@ +package io.quarkiverse.langchain4j.runtime.cache; + +import java.time.Instant; +import java.util.List; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.message.AiMessage; + +/** + * Represents a store for the {@link AiCache} state. + * Allows for flexibility in terms of where and how cache is stored. + */ +public interface AiCacheStore { + + public record CacheRecord(Embedding embedded, AiMessage response, Instant creation) { + public static CacheRecord of(Embedding embedded, AiMessage response) { + return new CacheRecord(embedded, response, Instant.now()); + } + }; + + /** + * Get all items stored in the cache. + * + * @param id Unique identifier for the cache + * @return {@link List} of {@link CacheRecord} + */ + public List getAll(Object id); + + /** + * Delete all items stored in the cache. + * + * @param id Unique identifier for the cache + */ + public void deleteCache(Object id); + + /** + * Update all items stored in the cache. 
+ * + * @param id Unique identifier for the cache + * @param items Items to update + */ + public void updateCache(Object id, List items); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java new file mode 100644 index 000000000..fc512d854 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java @@ -0,0 +1,204 @@ +package io.quarkiverse.langchain4j.runtime.cache; + +import java.time.Duration; +import java.util.Date; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.CosineSimilarity; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord; + +/** + * This {@link AiCache} default implementation. 
+ */ +public class FixedAiCache implements AiCache { + + private final Object id; + private final Integer maxMessages; + private final AiCacheStore store; + private final Double threshold; + private final Duration ttl; + private final String queryPrefix; + private final String passagePrefix; + private final EmbeddingModel embeddingModel; + private final ReentrantLock lock; + + public FixedAiCache(Builder builder) { + this.id = builder.id; + this.maxMessages = builder.maxSize; + this.store = builder.store; + this.ttl = builder.ttl; + this.threshold = builder.threshold; + this.queryPrefix = builder.queryPrefix; + this.passagePrefix = builder.passagePrefix; + this.embeddingModel = builder.embeddingModel; + this.lock = new ReentrantLock(); + } + + @Override + public Object id() { + return id; + } + + @Override + public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage aiResponse) { + + if (Objects.isNull(userMessage) || Objects.isNull(aiResponse)) { + return; + } + + String query; + if (Objects.isNull(systemMessage) || Objects.isNull(systemMessage.text()) || systemMessage.text().isBlank()) + query = userMessage.text(); + else + query = "%s%s%s".formatted(passagePrefix, systemMessage.text(), userMessage.text()); + + try { + + lock.lock(); + + List elements = store.getAll(id) + .stream() + .filter(this::checkTTL) + .collect(Collectors.toList()); + + if (elements.size() == maxMessages) { + return; + } + + elements.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse)); + store.updateCache(id, elements); + + } finally { + lock.unlock(); + } + } + + @Override + public Optional search(SystemMessage systemMessage, UserMessage userMessage) { + + if (Objects.isNull(userMessage)) + return Optional.empty(); + + String query; + if (Objects.isNull(systemMessage) || Objects.isNull(systemMessage.text()) || systemMessage.text().isBlank()) + query = userMessage.text(); + else + query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), 
userMessage.text()); + + try { + + lock.lock(); + + double maxScore = 0; + AiMessage result = null; + List records = store.getAll(id) + .stream() + .filter(this::checkTTL) + .collect(Collectors.toList()); + + for (var record : records) { + + var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), record.embedded()); + var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore); + + if (score >= threshold.doubleValue() && score >= maxScore) { + maxScore = score; + result = record.response(); + } + } + + store.updateCache(id, records); + return Optional.ofNullable(result); + + } finally { + lock.unlock(); + } + } + + private boolean checkTTL(CacheRecord record) { + + if (ttl == null) + return true; + + var expiredTime = Date.from(record.creation().plus(ttl)); + var currentTime = new Date(); + + if (currentTime.after(expiredTime)) { + return false; + } + + return true; + } + + @Override + public void clear() { + store.deleteCache(id); + } + + public static class Builder { + + Object id; + Integer maxSize; + AiCacheStore store; + Double threshold; + Duration ttl; + String queryPrefix; + String passagePrefix; + EmbeddingModel embeddingModel; + + private Builder(Object id) { + this.id = id; + } + + public static Builder create(Object id) { + return new Builder(id); + } + + public Builder maxSize(Integer maxSize) { + this.maxSize = maxSize; + return this; + } + + public Builder store(AiCacheStore store) { + this.store = store; + return this; + } + + public Builder threshold(Double threshold) { + this.threshold = threshold; + return this; + } + + public Builder ttl(Duration ttl) { + this.ttl = ttl; + return this; + } + + public Builder queryPrefix(String queryPrefix) { + this.queryPrefix = queryPrefix; + return this; + } + + public Builder passagePrefix(String passagePrefix) { + this.passagePrefix = passagePrefix; + return this; + } + + public Builder embeddingModel(EmbeddingModel embeddingModel) { + this.embeddingModel = 
embeddingModel; + return this; + } + + public AiCache build() { + return new FixedAiCache(this); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/InMemoryAiCacheStore.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/InMemoryAiCacheStore.java new file mode 100644 index 000000000..9f3e21771 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/InMemoryAiCacheStore.java @@ -0,0 +1,34 @@ +package io.quarkiverse.langchain4j.runtime.cache; + +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementation of {@link AiCacheStore}. + *

+ * This storage mechanism is transient and does not persist data across application restarts. + */ +public class InMemoryAiCacheStore implements AiCacheStore { + + private final Map> store = new ConcurrentHashMap<>(); + + @Override + public List getAll(Object memoryId) { + var elements = store.get(memoryId); + if (elements == null) + return new LinkedList<>(); + return elements; + } + + @Override + public void deleteCache(Object memoryId) { + store.remove(memoryId); + } + + @Override + public void updateCache(Object memoryId, List elements) { + store.put(memoryId, elements); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/config/AiCacheConfig.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/config/AiCacheConfig.java new file mode 100644 index 000000000..f0ba65ef1 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/config/AiCacheConfig.java @@ -0,0 +1,37 @@ +package io.quarkiverse.langchain4j.runtime.cache.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import java.time.Duration; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.cache") +public interface AiCacheConfig { + + /** + * Threshold used during semantic search to validate whether a cache result should be returned. + */ + @WithDefault("1") + double threshold(); + + /** + * Maximum number of messages to cache. + */ + @WithDefault("1") + int maxSize(); + + /** + * Time to live for messages stored in the cache. + */ + Optional ttl(); + + /** + * Allow to customize the embedding operation. 
+ */ + AiCacheEmbeddingConfig embedding(); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/config/AiCacheEmbeddingConfig.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/config/AiCacheEmbeddingConfig.java new file mode 100644 index 000000000..5e6839098 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/config/AiCacheEmbeddingConfig.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.runtime.cache.config; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface AiCacheEmbeddingConfig { + + /** + * Add a prefix to each \"query\" value before performing the embedding operation for the similarity search. + */ + Optional queryPrefix(); + + /** + * Add a prefix to each \"response\" value before performing the embedding operation. + */ + Optional passagePrefix(); +} diff --git a/docs/modules/ROOT/pages/ai-services.adoc b/docs/modules/ROOT/pages/ai-services.adoc index dee2dd5e3..504b4dc48 100644 --- a/docs/modules/ROOT/pages/ai-services.adoc +++ b/docs/modules/ROOT/pages/ai-services.adoc @@ -171,6 +171,174 @@ quarkus.langchain4j.openai.m1.api-key=sk-... quarkus.langchain4j.huggingface.m2.api-key=sk-... ---- +[#cache] +== Configuring the Cache + +If necessary, a semantic cache can be enabled to maintain a fixed number of questions and answers previously asked to the LLM, thus reducing the number of API calls. + +The `@CacheResult` annotation enables semantic caching and can be used at the class or method level. When used at the class level, it indicates that all methods of the AiService will perform a cache lookup before making a call to the LLM. This approach provides a convenient way to enable the caching for all methods of a `@RegisterAiService`. + +[source,java] +---- +@RegisterAiService +@CacheResult +@SystemMessage("...") +public interface LLMService { + // Cache is enabled for all methods + ... 
+} + +---- + +On the other hand, using `@CacheResult` at the method level allows fine-grained control over where the cache is enabled. + +[source,java] +---- +@RegisterAiService +@SystemMessage("...") +public interface LLMService { + + @CacheResult + @UserMessage("...") + public String method1(...); // Cache is enabled for this method + + @UserMessage("...") + public String method2(...); // Cache is not enabled for this method +} + +---- + +[IMPORTANT] +==== +Each method annotated with `@CacheResult` will have its own cache shared by all users. +==== + +=== Cache properties + +The following properties can be used to customize the cache configuration: + +- `quarkus.langchain4j.cache.threshold`: Specifies the threshold used during semantic search to determine whether a cached result should be returned. This threshold defines the similarity measure between new queries and cached entries. (`default 1`) +- `quarkus.langchain4j.cache.max-size`: Sets the maximum number of messages to cache. This property helps control memory usage by limiting the size of each cache. (`default 10`) +- `quarkus.langchain4j.cache.ttl`: Defines the time-to-live for messages stored in the cache. Messages that exceed the TTL are automatically removed. (`default 5m`) +- `quarkus.langchain4j.cache.embedding.name`: Specifies the name of the embedding model to use. +- `quarkus.langchain4j.cache.embedding.query-prefix`: Adds a prefix to each "query" value before performing the embedding operation. +- `quarkus.langchain4j.cache.embedding.passage-prefix`: Adds a prefix to each "response" value before performing the embedding operation. + +By default, the cache uses the default embedding model provided by the LLM. If there are multiple embedding providers, the `quarkus.langchain4j.cache.embedding.name` property can be used to choose which one to use. + +In the following example, there are two different embedding providers: + +`pom.xml`: + +[source,xml,subs=attributes+] +---- +...
+ + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai + {project-version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + {project-version} + + +... +---- + +`application.properties`: + +[source,properties] +---- +# OpenAI configuration +quarkus.langchain4j.service1.chat-model.provider=openai +quarkus.langchain4j.service1.embedding-model.provider=openai +quarkus.langchain4j.openai.service1.api-key=sk-... + +# Watsonx configuration +quarkus.langchain4j.service2.chat-model.provider=watsonx +quarkus.langchain4j.service2.embedding-model.provider=watsonx +quarkus.langchain4j.watsonx.service2.base-url=... +quarkus.langchain4j.watsonx.service2.api-key=... +quarkus.langchain4j.watsonx.service2.project-id=... +quarkus.langchain4j.watsonx.service2.embedding-model.model-id=... + +# The cache will use the embedding model provided by watsonx +quarkus.langchain4j.cache.embedding.name=service2 +---- + +When an xref:in-process-embedding.adoc[in-process embedding model] is to be used: + +`pom.xml`: + +[source,xml,subs=attributes+] +---- +... + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai + {project-version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + {project-version} + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2 + 0.31.0 + + + dev.langchain4j + langchain4j-core + + + + +... +---- + +`application.properties`: + +[source,properties] +---- +# OpenAI configuration +quarkus.langchain4j.service1.chat-model.provider=openai +quarkus.langchain4j.service1.embedding-model.provider=openai +quarkus.langchain4j.openai.service1.api-key=sk-... + +# Watsonx configuration +quarkus.langchain4j.service2.chat-model.provider=watsonx +quarkus.langchain4j.service2.embedding-model.provider=watsonx +quarkus.langchain4j.watsonx.service2.base-url=... +quarkus.langchain4j.watsonx.service2.api-key=... +quarkus.langchain4j.watsonx.service2.project-id=... +quarkus.langchain4j.watsonx.service2.embedding-model.model-id=...
+ +# The cache will use the in-process embedding model AllMiniLmL6V2EmbeddingModel +quarkus.langchain4j.embedding-model.provider=dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel +---- + +=== Advanced usage +The `cacheProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the `AiCacheProvider`. The default value of this annotation is `RegisterAiService.BeanAiCacheProviderSupplier.class` which means that the AiService will use whatever `AiCacheProvider` bean is configured by the application or the default one provided by the extension. + +The extension provides a default implementation of `AiCacheProvider` which does two things: + +* It uses whatever `AiCacheStore` bean is configured as the cache store. The default implementation is `InMemoryAiCacheStore`. +** If the application provides its own `AiCacheStore` bean, that will be used instead of the default `InMemoryAiCacheStore`. + +* It leverages the available configuration options under `quarkus.langchain4j.cache` to construct the `AiCacheProvider`. +** The default configuration values result in the usage of `FixedAiCache` with a size of ten. + +[source,java] +---- +@RegisterAiService(cacheProviderSupplier = CustomAiCacheProvider.class) +---- + [#memory] == Configuring the Context (Memory) @@ -288,10 +456,7 @@ This guidance aims to cover all crucial aspects of designing AI services with Qu By default, @RegisterAiService annotated interfaces don't moderate content. However, users can opt in to having the LLM moderate content by annotating the method with `@Moderate`.
-For moderation to work, the following criteria need to be met: - -* A CDI bean for `dev.langchain4j.model.moderation.ModerationModel` must be configured (the `quarkus-langchain4j-openai` and `quarkus-langchain4j-azure-openai` provide one out of the box) -* The interface must be configured with `@RegisterAiService(moderationModelSupplier = RegisterAiService.BeanModerationModelSupplier.class)` +For moderation to work, a CDI bean for `dev.langchain4j.model.moderation.ModerationModel` must be configured (the `quarkus-langchain4j-openai` and `quarkus-langchain4j-azure-openai` provide one out of the box). === Advanced usage An alternative to providing a CDI bean is to configure the interface with `@RegisterAiService(moderationModelSupplier = MyCustomSupplier.class)` diff --git a/integration-tests/cache/pom.xml b/integration-tests/cache/pom.xml new file mode 100644 index 000000000..1df109793 --- /dev/null +++ b/integration-tests/cache/pom.xml @@ -0,0 +1,142 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-tests-cache + Quarkus LangChain4j - Integration Tests - Cache + + true + + + + io.quarkus + quarkus-resteasy-reactive-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-bam + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + ${project.version} + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2 + ${langchain4j-embeddings.version} + + + dev.langchain4j + langchain4j-core + + + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + pom + test + + + * + * + + + + + org.junit.jupiter + junit-jupiter-api + test + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + 
maven-failsafe-plugin + + + + integration-test + verify + + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + \ No newline at end of file diff --git a/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService1.java b/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService1.java new file mode 100644 index 000000000..182eeea48 --- /dev/null +++ b/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService1.java @@ -0,0 +1,11 @@ +package org.acme.example.cache.MultiEmbedding; + +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.CacheResult; +import io.quarkiverse.langchain4j.RegisterAiService; + +@RegisterAiService(modelName = "service1") +@CacheResult +public interface AiService1 { + public String poem(@UserMessage("{text}") String text); +} diff --git a/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService2.java b/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService2.java new file mode 100644 index 000000000..7faf70124 --- /dev/null +++ b/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService2.java @@ -0,0 +1,11 @@ +package org.acme.example.cache.MultiEmbedding; + +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.CacheResult; +import io.quarkiverse.langchain4j.RegisterAiService; + +@RegisterAiService(modelName = "service2") +@CacheResult +public interface AiService2 { + public String poem(@UserMessage("{text}") String text); +} diff --git a/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService3.java b/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService3.java 
new file mode 100644 index 000000000..e1d29ab1f --- /dev/null +++ b/integration-tests/cache/src/main/java/org/acme/example/cache/MultiEmbedding/AiService3.java @@ -0,0 +1,33 @@ +package org.acme.example.cache.MultiEmbedding; + +import java.util.List; +import java.util.function.Supplier; + +import org.acme.example.cache.MultiEmbedding.AiService3.CustomChatLanguageModel; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.CacheResult; +import io.quarkiverse.langchain4j.RegisterAiService; + +@RegisterAiService(modelName = "service3", chatLanguageModelSupplier = CustomChatLanguageModel.class) +@CacheResult +public interface AiService3 { + public String poem(@UserMessage("{text}") String text); + + public static class CustomChatLanguageModel implements Supplier { + + @Override + public ChatLanguageModel get() { + return new ChatLanguageModel() { + @Override + public Response generate(List messages) { + return Response.from(AiMessage.from("Hello")); + } + }; + } + } +} diff --git a/integration-tests/cache/src/main/resources/application.properties b/integration-tests/cache/src/main/resources/application.properties new file mode 100644 index 000000000..19a2a64c2 --- /dev/null +++ b/integration-tests/cache/src/main/resources/application.properties @@ -0,0 +1,16 @@ +quarkus.langchain4j.service1.chat-model.provider=bam +quarkus.langchain4j.service1.embedding-model.provider=bam +quarkus.langchain4j.bam.service1.api-key=test +quarkus.langchain4j.bam.service1.base-url=http://test +quarkus.langchain4j.bam.service1.embedding-model.model-id=ibm/slate.125m.english.rtrvr +quarkus.langchain4j.service2.chat-model.provider=watsonx +quarkus.langchain4j.watsonx.service2.api-key=test +quarkus.langchain4j.watsonx.service2.base-url=http://test 
+quarkus.langchain4j.watsonx.service2.project-id=test +quarkus.langchain4j.watsonx.service2.embedding-model.model-id=ibm/slate.125m.english.rtrvr +quarkus.langchain4j.service3.chat-model.provider=bam +quarkus.langchain4j.service3.embedding-model.provider=bam +quarkus.langchain4j.bam.service3.api-key=test3 +quarkus.langchain4j.bam.service3.base-url=http://test3 + +quarkus.langchain4j.embedding-model.provider=dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel diff --git a/integration-tests/cache/src/test/java/org/acme/example/cache/MultiEmbedding/MultipleEmbeddingTest.java b/integration-tests/cache/src/test/java/org/acme/example/cache/MultiEmbedding/MultipleEmbeddingTest.java new file mode 100644 index 000000000..8cd67726f --- /dev/null +++ b/integration-tests/cache/src/test/java/org/acme/example/cache/MultiEmbedding/MultipleEmbeddingTest.java @@ -0,0 +1,40 @@ +package org.acme.example.cache.MultiEmbedding; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ManagedContext; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class MultipleEmbeddingTest { + + @Inject + AiCacheStore aiCacheStore; + + @Inject + AiService3 service3; + + @Inject + EmbeddingModel embeddingModel; + + @Test + void test() { + + ArcContainer container = Arc.container(); + ManagedContext requestContext = container.requestContext(); + String cacheId = requestContext.getState() + "#" + AiService3.class.getName() + ".poem"; + + service3.poem("test"); + var messages = aiCacheStore.getAll(cacheId); + assertEquals(1, messages.size()); + assertEquals(embeddingModel.embed("test").content(), messages.get(0).embedded()); + } +} diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml 
index b7b9f4165..b885d6ecf 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -20,6 +20,7 @@ azure-openai multiple-providers mistralai + cache devui devui-multiple-embedding-models embed-all-minilm-l6-v2-q diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java new file mode 100644 index 000000000..1ba6987f6 --- /dev/null +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java @@ -0,0 +1,123 @@ +package io.quarkiverse.langchain4j.bam.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.CacheResult; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ManagedContext; +import io.quarkus.test.QuarkusUnitTest; + +public class CacheClassTest { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + 
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL) + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY) + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); + + @RegisterAiService + @Singleton + @CacheResult + interface LLMService { + + @UserMessage("This is a userMessage {text}") + String chat(String text); + + @UserMessage("This is a userMessage {text}") + String chat2(String text); + + @Singleton + public class CustomChatModel implements ChatLanguageModel { + @Override + public Response generate(List messages) { + return Response.from(AiMessage.from("result")); + } + } + + @Singleton + public class CustomEmbedding implements EmbeddingModel { + + @Override + public Response> embedAll(List textSegments) { + return Response.from(List.of(Embedding.from(es))); + } + } + } + + @Inject + LLMService service; + + @Inject + AiCacheStore aiCacheStore; + + @Test + void cache_test() { + + String chatCacheId = "#" + LLMService.class.getName() + ".chat"; + String chat2CacheId = "#" + LLMService.class.getName() + ".chat2"; + + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + assertEquals(0, aiCacheStore.getAll(chat2CacheId).size()); + + service.chat("chat"); + assertEquals(1, aiCacheStore.getAll(chatCacheId).size()); + assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector()); + assertEquals(0, aiCacheStore.getAll(chat2CacheId).size()); + + service.chat2("chat2"); + assertEquals(1, aiCacheStore.getAll(chat2CacheId).size()); + assertEquals("result", aiCacheStore.getAll(chat2CacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chat2CacheId).get(0).embedded().vector()); + } + + @Test + @ActivateRequestContext + void cache_test_with_request_context() { + + ArcContainer container = Arc.container(); + ManagedContext requestContext = 
container.requestContext(); + String chatCacheId = requestContext.getState() + "#" + LLMService.class.getName() + ".chat"; + String chat2CacheId = requestContext.getState() + "#" + LLMService.class.getName() + ".chat2"; + + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + service.chat("chat"); + assertEquals(1, aiCacheStore.getAll(chatCacheId).size()); + assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector()); + + service.chat2("chat2"); + assertEquals(1, aiCacheStore.getAll(chat2CacheId).size()); + assertEquals("result", aiCacheStore.getAll(chat2CacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chat2CacheId).get(0).embedded().vector()); + } + + static float[] es = { + 0.039016734808683395f, + 0.010098248720169067f, + -0.02687959559261799f, + }; +} diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java new file mode 100644 index 000000000..d8d02f0c4 --- /dev/null +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java @@ -0,0 +1,162 @@ +package io.quarkiverse.langchain4j.bam.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import java.util.stream.Collectors; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import 
dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.CacheResult; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ManagedContext; +import io.quarkus.test.QuarkusUnitTest; + +public class CacheConfigTest { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL) + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.cache.ttl", "2s") + .overrideRuntimeConfigKey("quarkus.langchain4j.cache.max-size", "3") + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); + + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @Singleton + interface LLMService { + + @SystemMessage("TEST") + @UserMessage("{text}") + @CacheResult + String chat(String text); + + @Singleton + public class CustomChatModel implements ChatLanguageModel { + @Override + public Response generate(List messages) { + String m = messages.stream().map(ChatMessage::text).collect(Collectors.joining("")); + return Response.from(AiMessage.from("cache: " + m)); + } + } + + @Singleton + public class CustomEmbedding implements EmbeddingModel { + + @Override + public Response> embedAll(List textSegments) { + if (textSegments.get(0).text().equals("TESTFIRST")) + return Response.from(List.of(Embedding.from(first))); + else if (textSegments.get(0).text().equals("TESTSECOND")) + return 
Response.from(List.of(Embedding.from(second))); + else if (textSegments.get(0).text().equals("TESTTHIRD")) + return Response.from(List.of(Embedding.from(third))); + else if (textSegments.get(0).text().equals("TESTFOURTH")) + return Response.from(List.of(Embedding.from(fourth))); + return null; + } + } + } + + @Inject + LLMService service; + + @Inject + AiCacheStore aiCacheStore; + + @Test + @Order(1) + void cache_ttl_test() throws InterruptedException { + + String cacheId = "#" + LLMService.class.getName() + ".chat"; + aiCacheStore.deleteCache(cacheId); + + service.chat("FIRST"); + service.chat("SECOND"); + assertEquals(2, aiCacheStore.getAll(cacheId).size()); + assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text()); + assertEquals(first, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(1).response().text()); + assertEquals(second, aiCacheStore.getAll(cacheId).get(1).embedded().vector()); + + Thread.sleep(3000); + service.chat("THIRD"); + assertEquals(1, aiCacheStore.getAll(cacheId).size()); + assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(0).response().text()); + assertEquals(third, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + } + + @Test + @Order(2) + void cache_max_size_test() { + + String cacheId = "#" + LLMService.class.getName() + ".chat"; + aiCacheStore.deleteCache(cacheId); + + service.chat("FIRST"); + assertEquals(1, aiCacheStore.getAll(cacheId).size()); + assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text()); + assertEquals(first, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + + service.chat("SECOND"); + service.chat("THIRD"); + service.chat("FOURTH"); + assertEquals(3, aiCacheStore.getAll(cacheId).size()); + assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text()); + assertEquals(first, 
aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(1).response().text()); + assertEquals(second, aiCacheStore.getAll(cacheId).get(1).embedded().vector()); + assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(2).response().text()); + assertEquals(third, aiCacheStore.getAll(cacheId).get(2).embedded().vector()); + } + + private String getContext(String methodName) { + ArcContainer container = Arc.container(); + ManagedContext requestContext = container.requestContext(); + return requestContext.getState() + "#" + LLMService.class.getName() + "." + methodName; + } + + static float[] first = { + 0.039016734808683395f, + 0.010098248720169067f, + -0.02687959559261799f, + }; + + static float[] second = { + 0.139016734108685515f, + 0.211198249720169167f, + 0.62687959559261799f, + }; + + static float[] third = { + -0.229016734199685515f, + -0.211198249721169127f, + -0.62999959559261719f, + }; + + static float[] fourth = { + -1.229016734199685515f, + 0.211198249721169127f, + 3.62999959559261719f, + }; +} diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java new file mode 100644 index 000000000..3a8a91514 --- /dev/null +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java @@ -0,0 +1,122 @@ +package io.quarkiverse.langchain4j.bam.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + 
+import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.CacheResult; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ManagedContext; +import io.quarkus.test.QuarkusUnitTest; + +public class CacheMethodTest { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL) + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY) + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); + + @RegisterAiService + @Singleton + interface LLMService { + + @SystemMessage("This is a systemMessage") + @UserMessage("This is a userMessage {text}") + @CacheResult + String chat(String text); + + @SystemMessage("This is a systemMessage") + @UserMessage("This is a userMessage {text}") + String chatNoCache(String text); + + @Singleton + public class CustomChatModel implements ChatLanguageModel { + @Override + public Response generate(List messages) { + return Response.from(AiMessage.from("result")); + } + } + + @Singleton + public class CustomEmbedding implements EmbeddingModel { + + @Override + public Response> embedAll(List textSegments) { + return Response.from(List.of(Embedding.from(es))); + } + } + } + + @Inject + LLMService service; + + @Inject + AiCacheStore aiCacheStore; + + @Test + void cache_test() { + + String chatCacheId = "#" + 
LLMService.class.getName() + ".chat"; + String chatNoCacheCacheId = "#" + LLMService.class.getName() + ".chatNoCache"; + + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size()); + service.chatNoCache("noCache"); + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size()); + + service.chat("cache"); + assertEquals(1, aiCacheStore.getAll(chatCacheId).size()); + assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector()); + assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size()); + } + + @Test + @ActivateRequestContext + void cache_test_with_request_context() { + + ArcContainer container = Arc.container(); + ManagedContext requestContext = container.requestContext(); + String chatNoCacheId = requestContext.getState() + "#" + LLMService.class.getName() + ".chatNoCache"; + String chatCacheId = requestContext.getState() + "#" + LLMService.class.getName() + ".chat"; + + assertEquals(0, aiCacheStore.getAll(chatNoCacheId).size()); + service.chatNoCache("noCache"); + assertEquals(0, aiCacheStore.getAll(chatNoCacheId).size()); + + service.chat("cache"); + assertEquals(1, aiCacheStore.getAll(chatCacheId).size()); + assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector()); + } + + static float[] es = { + 0.039016734808683395f, + 0.010098248720169067f, + -0.02687959559261799f + }; +} diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java new file mode 100644 index 000000000..0984e6021 --- /dev/null +++ 
// b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java @@ -0,0 +1,120 @@
package io.quarkiverse.langchain4j.bam.deployment;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.List;
import java.util.stream.Collectors;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.CacheResult;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
import io.quarkus.test.QuarkusUnitTest;

/**
 * Verifies that the configured {@code passage-prefix} / {@code query-prefix} are
 * prepended to the text handed to the embedding model when the AI cache stores
 * and looks up responses.
 */
public class CachePrefixTest {

    @RegisterExtension
    static QuarkusUnitTest unitTest = new QuarkusUnitTest()
            .overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL)
            .overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY)
            .overrideRuntimeConfigKey("quarkus.langchain4j.cache.embedding.passage-prefix", "passage: ")
            .overrideRuntimeConfigKey("quarkus.langchain4j.cache.embedding.query-prefix", "query: ")
            .overrideRuntimeConfigKey("quarkus.langchain4j.cache.ttl", "2s")
            .overrideRuntimeConfigKey("quarkus.langchain4j.cache.max-size", "3")
            .setArchiveProducer(
                    () -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

    @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
    @Singleton
    interface LLMService {

        @SystemMessage("TEST")
        @UserMessage("{text}")
        @CacheResult
        String chat(String text);

        /** Stub chat model: echoes the concatenated messages, prefixed with "cache: ". */
        @Singleton
        public class CustomChatModel implements ChatLanguageModel {
            @Override
            // NOTE(review): generic type arguments were stripped in the source dump;
            // restored here from the langchain4j ChatLanguageModel contract.
            public Response<AiMessage> generate(List<ChatMessage> messages) {
                String m = messages.stream().map(ChatMessage::text).collect(Collectors.joining(""));
                return Response.from(AiMessage.from("cache: " + m));
            }
        }

        /**
         * Stub embedding model: asserts that the expected prefix ("passage: " on store,
         * "query: " on lookup) was applied, and returns a fixed vector per message.
         */
        @Singleton
        public class CustomEmbedding implements EmbeddingModel {

            @Override
            // NOTE(review): generics restored from the langchain4j EmbeddingModel contract.
            public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {

                if (textSegments.get(0).text().equals("passage: TESTfirstMessage")) {
                    assertEquals("passage: TESTfirstMessage", textSegments.get(0).text());
                    return Response.from(List.of(Embedding.from(firstMessage)));
                } else if (textSegments.get(0).text().equals("query: TESTfirstMessage")) {
                    assertEquals("query: TESTfirstMessage", textSegments.get(0).text());
                    return Response.from(List.of(Embedding.from(firstMessage)));
                } else if (textSegments.get(0).text().equals("passage: TESTsecondMessage")) {
                    assertEquals("passage: TESTsecondMessage", textSegments.get(0).text());
                    return Response.from(List.of(Embedding.from(secondMessage)));
                } else if (textSegments.get(0).text().equals("query: TESTsecondMessage")) {
                    assertEquals("query: TESTsecondMessage", textSegments.get(0).text());
                    return Response.from(List.of(Embedding.from(secondMessage)));
                }

                // Unexpected input: returning null makes the caller fail loudly.
                return null;
            }
        }
    }

    @Inject
    LLMService service;

    @Inject
    AiCacheStore aiCacheStore;

    @Test
    @Order(1)
    void cache_prefix_test() throws InterruptedException {

        String cacheId = "#" + LLMService.class.getName() + ".chat";
        aiCacheStore.deleteCache(cacheId);

        service.chat("firstMessage");
        service.chat("secondMessage");
        assertEquals(2, aiCacheStore.getAll(cacheId).size());
        assertEquals("cache: TESTfirstMessage", aiCacheStore.getAll(cacheId).get(0).response().text());
        // assertArrayEquals compares element-by-element; assertEquals on float[] only
        // compares references (no array overload exists in JUnit 5 Assertions).
        assertArrayEquals(firstMessage, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
        assertEquals("cache: TESTsecondMessage", aiCacheStore.getAll(cacheId).get(1).response().text());
        assertArrayEquals(secondMessage, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
    }

    // Fixed embedding vector returned for "firstMessage".
    static float[] firstMessage = {
            0.039016734808683395f,
            0.010098248720169067f,
            -0.02687959559261799f,
    };

    // Fixed embedding vector returned for "secondMessage".
    static float[] secondMessage = {
            0.139016734108685515f,
            0.211198249720169167f,
            0.62687959559261799f,
    };
}
// diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheWithToolTest.java
//          b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheWithToolTest.java
// new file mode 100644 index 000000000..be3e06b7e --- /dev/null +++ b/.../CacheWithToolTest.java @@ -0,0 +1,38 @@
package io.quarkiverse.langchain4j.bam.deployment;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.CacheResult;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

/**
 * Verifies that combining {@code @CacheResult} with tools on an AI service is
 * rejected at deployment time with a descriptive error.
 */
public class CacheWithToolTest {

    @RegisterExtension
    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
            // Deployment must fail: cache + tools is an unsupported combination.
            .assertException(t -> {
                assertThat(t).isInstanceOf(RuntimeException.class);
                assertEquals("The cache cannot be used in combination with the tools. Affected class: %s"
                        .formatted(AiService.class.getName()), t.getMessage());
            });

    @RegisterAiService(tools = Object.class)
    public interface AiService {
        @CacheResult
        public String poem(@UserMessage("{text}") String text);
    }

    @Test
    void test() {
        // Startup is expected to fail before any test method runs.
        fail("Should not be called");
    }
}