Skip to content

Commit 631e904

Browse files
committed
Compose the cache message with SystemMessage + UserMessage
1 parent 9747dc5 commit 631e904

File tree

16 files changed

+253
-2107
lines changed

16 files changed

+253
-2107
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheBuildConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ public interface AiCacheBuildConfig {
1212
/**
1313
* Ai Cache embedding model related settings
1414
*/
15-
CacheEmbeddingModelConfig embeddingModel();
15+
CacheEmbeddingModelConfig embedding();
1616
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiCacheProcessor.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import org.jboss.jandex.ClassType;
99
import org.jboss.jandex.IndexView;
1010

11+
import dev.langchain4j.model.embedding.EmbeddingModel;
12+
import io.quarkiverse.langchain4j.ModelName;
1113
import io.quarkiverse.langchain4j.runtime.AiCacheRecorder;
1214
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
1315
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
@@ -49,23 +51,31 @@ void setupBeans(AiCacheBuildConfig cacheBuildConfig,
4951
}
5052
}
5153

52-
String embeddingModel = NamedConfigUtil.DEFAULT_NAME;
53-
if (cacheBuildConfig.embeddingModel() != null)
54-
embeddingModel = cacheBuildConfig.embeddingModel().name().orElse(NamedConfigUtil.DEFAULT_NAME);
54+
String embeddingModelName = NamedConfigUtil.DEFAULT_NAME;
55+
if (cacheBuildConfig.embedding() != null)
56+
embeddingModelName = cacheBuildConfig.embedding().name().orElse(NamedConfigUtil.DEFAULT_NAME);
5557

56-
aiCacheBuildItemProducer.produce(new AiCacheBuildItem(enableCache, embeddingModel));
58+
aiCacheBuildItemProducer.produce(new AiCacheBuildItem(enableCache, embeddingModelName));
5759

5860
if (enableCache) {
59-
var configurator = SyntheticBeanBuildItem
61+
SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
6062
.configure(AiCacheProvider.class)
6163
.setRuntimeInit()
6264
.addInjectionPoint(ClassType.create(AiCacheStore.class))
6365
.scope(ApplicationScoped.class)
64-
.createWith(recorder.messageWindow(cacheConfig))
66+
.createWith(recorder.messageWindow(cacheConfig, embeddingModelName))
6567
.defaultBean();
6668

69+
if (NamedConfigUtil.isDefault(embeddingModelName)) {
70+
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL));
71+
} else {
72+
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL),
73+
AnnotationInstance.builder(ModelName.class).add("value", embeddingModelName).build());
74+
}
75+
6776
syntheticBeanProducer.produce(configurator.done());
6877
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(AiCacheStore.class));
78+
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(EmbeddingModel.class));
6979
}
7080
}
7181
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
444444

445445
String chatModelName = bi.getChatModelName();
446446
String moderationModelName = bi.getModerationModelName();
447-
String aiCacheEmbeddingModelName = aiCacheBuildItem.getEmbeddingModelName();
448447
boolean enableCache = aiCacheBuildItem.isEnable();
449448

450449
// It is not possible to use the cache in combination with the tools.
@@ -464,7 +463,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
464463
retrievalAugmentorSupplierClassName,
465464
auditServiceClassSupplierName, moderationModelSupplierClassName, chatModelName,
466465
moderationModelName,
467-
aiCacheEmbeddingModelName,
468466
needsStreamingChatModel,
469467
needsModerationModel,
470468
enableCache)))
@@ -560,13 +558,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
560558
} else {
561559
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.AI_CACHE_PROVIDER));
562560
}
563-
564-
if (NamedConfigUtil.isDefault(aiCacheEmbeddingModelName)) {
565-
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL));
566-
} else {
567-
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL),
568-
AnnotationInstance.builder(ModelName.class).add("value", aiCacheEmbeddingModelName).build());
569-
}
570561
needsAiCacheProvider = true;
571562
}
572563

@@ -596,7 +587,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
596587
}
597588
if (needsAiCacheProvider) {
598589
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.AI_CACHE_PROVIDER));
599-
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.EMBEDDING_MODEL));
600590
}
601591
if (!allToolNames.isEmpty()) {
602592
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import java.time.Duration;
44
import java.util.function.Function;
55

6+
import dev.langchain4j.model.embedding.EmbeddingModel;
7+
import io.quarkiverse.langchain4j.ModelName;
68
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
79
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
810
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
@@ -14,15 +16,27 @@
1416
@Recorder
1517
public class AiCacheRecorder {
1618

17-
public Function<SyntheticCreationalContext<AiCacheProvider>, AiCacheProvider> messageWindow(AiCacheConfig config) {
19+
public Function<SyntheticCreationalContext<AiCacheProvider>, AiCacheProvider> messageWindow(AiCacheConfig config,
20+
String embeddingModelName) {
1821
return new Function<>() {
1922
@Override
2023
public AiCacheProvider apply(SyntheticCreationalContext<AiCacheProvider> context) {
2124

25+
EmbeddingModel embeddingModel;
2226
AiCacheStore aiCacheStore = context.getInjectedReference(AiCacheStore.class);
27+
28+
if (NamedConfigUtil.isDefault(embeddingModelName)) {
29+
embeddingModel = context.getInjectedReference(EmbeddingModel.class);
30+
} else {
31+
embeddingModel = context.getInjectedReference(EmbeddingModel.class,
32+
ModelName.Literal.of(embeddingModelName));
33+
}
34+
2335
double threshold = config.threshold();
2436
int maxSize = config.maxSize();
2537
Duration ttl = config.ttl().orElse(null);
38+
String queryPrefix = config.embedding().queryPrefix().orElse("");
39+
String passagePrefix = config.embedding().passagePrefix().orElse("");
2640

2741
return new AiCacheProvider() {
2842
@Override
@@ -32,6 +46,9 @@ public AiCache get(Object memoryId) {
3246
.ttl(ttl)
3347
.maxSize(maxSize)
3448
.threshold(threshold)
49+
.queryPrefix(queryPrefix)
50+
.passagePrefix(passagePrefix)
51+
.embeddingModel(embeddingModel)
3552
.store(aiCacheStore)
3653
.build();
3754
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import dev.langchain4j.memory.chat.ChatMemoryProvider;
1919
import dev.langchain4j.model.chat.ChatLanguageModel;
2020
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
21-
import dev.langchain4j.model.embedding.EmbeddingModel;
2221
import dev.langchain4j.model.moderation.ModerationModel;
2322
import dev.langchain4j.rag.RetrievalAugmentor;
2423
import dev.langchain4j.retriever.Retriever;
@@ -253,13 +252,6 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
253252
}
254253
}
255254

256-
if (NamedConfigUtil.isDefault(info.aiCacheEmbeddingModelName())) {
257-
aiServiceContext.embeddingModel = creationalContext.getInjectedReference(EmbeddingModel.class);
258-
} else {
259-
aiServiceContext.embeddingModel = creationalContext.getInjectedReference(EmbeddingModel.class,
260-
ModelName.Literal.of(info.aiCacheEmbeddingModelName()));
261-
}
262-
263255
aiServiceContext.aiCaches = new ConcurrentHashMap<>();
264256
}
265257

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import dev.langchain4j.agent.tool.ToolExecutionRequest;
2727
import dev.langchain4j.agent.tool.ToolExecutor;
2828
import dev.langchain4j.agent.tool.ToolSpecification;
29-
import dev.langchain4j.data.embedding.Embedding;
3029
import dev.langchain4j.data.message.AiMessage;
3130
import dev.langchain4j.data.message.ChatMessage;
3231
import dev.langchain4j.data.message.SystemMessage;
@@ -197,15 +196,14 @@ public void accept(Response<AiMessage> message) {
197196
if (cache != null) {
198197
log.debug("Attempting to obtain AI response from cache");
199198

200-
Embedding query = context.embeddingModel.embed(userMessage.text()).content();
201-
var cacheResponse = cache.search(query);
199+
var cacheResponse = cache.search(systemMessage.orElse(null), userMessage);
202200

203201
if (cacheResponse.isPresent()) {
204202
log.debug("Return cached response");
205203
response = Response.from(cacheResponse.get());
206204
} else {
207205
response = executeLLMCall(context, messages, moderationFuture, toolSpecifications);
208-
cache.add(query, response.content());
206+
cache.add(systemMessage.orElse(null), userMessage, response.content());
209207
}
210208

211209
} else {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ public record DeclarativeAiServiceCreateInfo(String serviceClassName,
1313
String moderationModelSupplierClassName,
1414
String chatModelName,
1515
String moderationModelName,
16-
String aiCacheEmbeddingModelName,
1716
boolean needsStreamingChatModel,
1817
boolean needsModerationModel,
1918
boolean enableCache) {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import java.util.function.BiConsumer;
55

66
import dev.langchain4j.memory.ChatMemory;
7-
import dev.langchain4j.model.embedding.EmbeddingModel;
87
import dev.langchain4j.service.AiServiceContext;
98
import io.quarkiverse.langchain4j.RegisterAiService;
109
import io.quarkiverse.langchain4j.audit.AuditService;
@@ -16,7 +15,6 @@ public class QuarkusAiServiceContext extends AiServiceContext {
1615
public AuditService auditService;
1716
public Map<Object, AiCache> aiCaches;
1817
public AiCacheProvider aiCacheProvider;
19-
public EmbeddingModel embeddingModel;
2018

2119
// needed by Arc
2220
public QuarkusAiServiceContext() {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/AiCache.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import java.util.Optional;
44

5-
import dev.langchain4j.data.embedding.Embedding;
65
import dev.langchain4j.data.message.AiMessage;
6+
import dev.langchain4j.data.message.SystemMessage;
7+
import dev.langchain4j.data.message.UserMessage;
78

89
/**
910
* Represents the cache of a AI. It can be used to reduces response time for similar queries.
@@ -20,18 +21,20 @@ public interface AiCache {
2021
/**
2122
* Cache a new message.
2223
*
23-
* @param query Embedded value to add to the cache.
24-
* @param response Response returned by the AI to add to the cache.
24+
* @param systemMessage {@link SystemMessage} value to add to the cache.
25+
* @param userMessage {@link UserMessage} value to add to the cache.
26+
* @param aiResponse {@link AiMessage} value to add to the cache.
2527
*/
26-
void add(Embedding query, AiMessage response);
28+
void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage aiResponse);
2729

2830
/**
29-
* Check to see if there is a response in the cache that is semantically close to a cached query.
31+
* Check if there is a response in the cache that is semantically close to the cached items.
3032
*
31-
* @param query
33+
* @param systemMessage {@link SystemMessage} value to find in the cache.
34+
* @param userMessage {@link UserMessage} value to find in the cache.
3235
* @return
3336
*/
34-
Optional<AiMessage> search(Embedding query);
37+
Optional<AiMessage> search(SystemMessage systemMessage, UserMessage userMessage);
3538

3639
/**
3740
* Clears the cache.

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import java.util.Optional;
99
import java.util.concurrent.locks.ReentrantLock;
1010

11-
import dev.langchain4j.data.embedding.Embedding;
1211
import dev.langchain4j.data.message.AiMessage;
12+
import dev.langchain4j.data.message.SystemMessage;
13+
import dev.langchain4j.data.message.UserMessage;
14+
import dev.langchain4j.model.embedding.EmbeddingModel;
1315
import dev.langchain4j.store.embedding.CosineSimilarity;
1416
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord;
1517

@@ -23,6 +25,9 @@ public class MessageWindowAiCache implements AiCache {
2325
private final AiCacheStore store;
2426
private final Double threshold;
2527
private final Duration ttl;
28+
private final String queryPrefix;
29+
private final String passagePrefix;
30+
private final EmbeddingModel embeddingModel;
2631
private final ReentrantLock lock;
2732

2833
public MessageWindowAiCache(Builder builder) {
@@ -31,6 +36,9 @@ public MessageWindowAiCache(Builder builder) {
3136
this.store = builder.store;
3237
this.ttl = builder.ttl;
3338
this.threshold = builder.threshold;
39+
this.queryPrefix = builder.queryPrefix;
40+
this.passagePrefix = builder.passagePrefix;
41+
this.embeddingModel = builder.embeddingModel;
3442
this.lock = new ReentrantLock();
3543
}
3644

@@ -40,12 +48,18 @@ public Object id() {
4048
}
4149

4250
@Override
43-
public void add(Embedding query, AiMessage response) {
51+
public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage aiResponse) {
4452

45-
if (Objects.isNull(query) || Objects.isNull(response)) {
53+
if (Objects.isNull(userMessage) || Objects.isNull(aiResponse)) {
4654
return;
4755
}
4856

57+
String query;
58+
if (Objects.isNull(systemMessage) || Objects.isNull(systemMessage.text()) || systemMessage.text().isBlank())
59+
query = userMessage.text();
60+
else
61+
query = "%s%s%s".formatted(passagePrefix, systemMessage.text(), userMessage.text());
62+
4963
try {
5064

5165
lock.lock();
@@ -55,7 +69,7 @@ public void add(Embedding query, AiMessage response) {
5569
elements.remove(0);
5670
}
5771

58-
List<CacheRecord> update = new LinkedList<>();
72+
List<CacheRecord> items = new LinkedList<>();
5973
for (int i = 0; i < elements.size(); i++) {
6074

6175
var expiredTime = Date.from(elements.get(i).creation().plus(ttl));
@@ -64,23 +78,29 @@ public void add(Embedding query, AiMessage response) {
6478
if (currentTime.after(expiredTime))
6579
continue;
6680

67-
update.add(elements.get(i));
81+
items.add(elements.get(i));
6882
}
6983

70-
update.add(CacheRecord.of(query, response));
71-
store.updateCache(id, update);
84+
items.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
85+
store.updateCache(id, items);
7286

7387
} finally {
7488
lock.unlock();
7589
}
7690
}
7791

7892
@Override
79-
public Optional<AiMessage> search(Embedding query) {
93+
public Optional<AiMessage> search(SystemMessage systemMessage, UserMessage userMessage) {
8094

81-
if (Objects.isNull(query))
95+
if (Objects.isNull(userMessage))
8296
return Optional.empty();
8397

98+
String query;
99+
if (Objects.isNull(systemMessage) || Objects.isNull(systemMessage.text()) || systemMessage.text().isBlank())
100+
query = userMessage.text();
101+
else
102+
query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), userMessage.text());
103+
84104
var elements = store.getAll(id);
85105
double maxScore = 0;
86106
AiMessage result = null;
@@ -95,7 +115,7 @@ public Optional<AiMessage> search(Embedding query) {
95115
continue;
96116
}
97117

98-
var relevanceScore = CosineSimilarity.between(query, cacheRecord.embedded());
118+
var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), cacheRecord.embedded());
99119
var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
100120

101121
if (score >= threshold.doubleValue() && score >= maxScore) {
@@ -119,6 +139,9 @@ public static class Builder {
119139
AiCacheStore store;
120140
Double threshold;
121141
Duration ttl;
142+
String queryPrefix;
143+
String passagePrefix;
144+
EmbeddingModel embeddingModel;
122145

123146
private Builder(Object id) {
124147
this.id = id;
@@ -148,6 +171,21 @@ public Builder ttl(Duration ttl) {
148171
return this;
149172
}
150173

174+
public Builder queryPrefix(String queryPrefix) {
175+
this.queryPrefix = queryPrefix;
176+
return this;
177+
}
178+
179+
public Builder passagePrefix(String passagePrefix) {
180+
this.passagePrefix = passagePrefix;
181+
return this;
182+
}
183+
184+
public Builder embeddingModel(EmbeddingModel embeddingModel) {
185+
this.embeddingModel = embeddingModel;
186+
return this;
187+
}
188+
151189
public AiCache build() {
152190
return new MessageWindowAiCache(this);
153191
}

0 commit comments

Comments
 (0)