Skip to content

Commit 774216e

Browse files
committed
Generate the id for the Cache with a unique context if possible
1 parent 631e904 commit 774216e

File tree

10 files changed

+101
-76
lines changed

10 files changed

+101
-76
lines changed

core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
2424
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
2525
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
26+
import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache;
2627
import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore;
27-
import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache;
2828

2929
/**
3030
* Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by using the
@@ -101,7 +101,7 @@
101101
* Configures the way to obtain the {@link AiCacheProvider}.
102102
* <p>
103103
* Be default, Quarkus configures a {@link AiCacheProvider} bean that uses a {@link InMemoryAiCacheStore} bean as the
104-
* backing store. The default type for the actual {@link AiCache} is {@link MessageWindowAiCache} and it is configured with
104+
* backing store. The default type for the actual {@link AiCache} is {@link FixedAiCache} and it is configured with
105105
* the value of the {@code quarkus.langchain4j.cache.max-size} configuration property (which default to
106106
* 1) as a way of limiting the number of messages in each cache.
107107
* <p>

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
99
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
1010
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
11-
import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache;
11+
import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache;
1212
import io.quarkiverse.langchain4j.runtime.cache.config.AiCacheConfig;
1313
import io.quarkus.arc.SyntheticCreationalContext;
1414
import io.quarkus.runtime.annotations.Recorder;
@@ -41,7 +41,7 @@ public AiCacheProvider apply(SyntheticCreationalContext<AiCacheProvider> context
4141
return new AiCacheProvider() {
4242
@Override
4343
public AiCache get(Object memoryId) {
44-
return MessageWindowAiCache.Builder
44+
return FixedAiCache.Builder
4545
.create(memoryId)
4646
.ttl(ttl)
4747
.maxSize(maxSize)

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
252252
}
253253
}
254254

255-
aiServiceContext.aiCaches = new ConcurrentHashMap<>();
255+
if (aiServiceContext.aiCaches == null)
256+
aiServiceContext.aiCaches = new ConcurrentHashMap<>();
256257
}
257258

258259
return (T) aiServiceContext;

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,9 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
122122
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
123123
AiCache cache = null;
124124

125-
// TODO: REMOVE THIS COMMENT BEFORE MERGING THE PR.
126-
// - Understand how to implement the concept of cache for the stream responses.
127-
// - What do we have to do when we have the tools?
128-
129125
if (methodCreateInfo.isRequiresCache()) {
130-
Object cacheId = cacheId(methodCreateInfo, methodArgs);
131-
cache = context.aiCacheProvider.get(cacheId);
126+
Object cacheId = cacheId(methodCreateInfo);
127+
cache = context.cache(cacheId);
132128
}
133129
if (context.retrievalAugmentor != null) { // TODO extract method/class
134130
List<ChatMessage> chatMemory = context.hasChatMemory()
@@ -396,17 +392,15 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me
396392
return "default";
397393
}
398394

399-
private static Object cacheId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
395+
private static Object cacheId(AiServiceMethodCreateInfo createInfo) {
400396
for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) {
401397
Object memoryId = provider.getMemoryId();
402398
if (memoryId != null) {
403399
String perServiceSuffix = "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName();
404400
return memoryId + perServiceSuffix;
405401
}
406402
}
407-
408-
// fallback to the default since there is nothing else we can really use here
409-
return "default";
403+
return "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName();
410404
}
411405

412406
// TODO: share these methods with LangChain4j

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ public boolean hasCache() {
2929
return aiCaches != null;
3030
}
3131

32-
public AiCache cache(Object memoryId) {
33-
return aiCaches.computeIfAbsent(memoryId, ignored -> aiCacheProvider.get(memoryId));
32+
public AiCache cache(Object cacheId) {
33+
return aiCaches.computeIfAbsent(cacheId, ignored -> aiCacheProvider.get(cacheId));
3434
}
3535

3636
/**
@@ -58,7 +58,7 @@ private void clearAiCache() {
5858
if (aiCaches != null) {
5959
aiCaches.forEach(new BiConsumer<>() {
6060
@Override
61-
public void accept(Object memoryId, AiCache aiCache) {
61+
public void accept(Object cacheId, AiCache aiCache) {
6262
aiCache.clear();
6363
}
6464
});
Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import java.time.Duration;
44
import java.util.Date;
5-
import java.util.LinkedList;
65
import java.util.List;
76
import java.util.Objects;
87
import java.util.Optional;
98
import java.util.concurrent.locks.ReentrantLock;
9+
import java.util.stream.Collectors;
1010

1111
import dev.langchain4j.data.message.AiMessage;
1212
import dev.langchain4j.data.message.SystemMessage;
@@ -16,9 +16,9 @@
1616
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord;
1717

1818
/**
19-
* This {@link AiCache} implementation operates as a sliding window of messages.
19+
* This {@link AiCache} default implementation.
2020
*/
21-
public class MessageWindowAiCache implements AiCache {
21+
public class FixedAiCache implements AiCache {
2222

2323
private final Object id;
2424
private final Integer maxMessages;
@@ -30,7 +30,7 @@ public class MessageWindowAiCache implements AiCache {
3030
private final EmbeddingModel embeddingModel;
3131
private final ReentrantLock lock;
3232

33-
public MessageWindowAiCache(Builder builder) {
33+
public FixedAiCache(Builder builder) {
3434
this.id = builder.id;
3535
this.maxMessages = builder.maxSize;
3636
this.store = builder.store;
@@ -64,25 +64,17 @@ public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage
6464

6565
lock.lock();
6666

67-
List<CacheRecord> elements = store.getAll(id);
68-
if (elements.size() == maxMessages) {
69-
elements.remove(0);
70-
}
71-
72-
List<CacheRecord> items = new LinkedList<>();
73-
for (int i = 0; i < elements.size(); i++) {
74-
75-
var expiredTime = Date.from(elements.get(i).creation().plus(ttl));
76-
var currentTime = new Date();
77-
78-
if (currentTime.after(expiredTime))
79-
continue;
67+
List<CacheRecord> elements = store.getAll(id)
68+
.stream()
69+
.filter(this::checkTTL)
70+
.collect(Collectors.toList());
8071

81-
items.add(elements.get(i));
72+
if (elements.size() == maxMessages) {
73+
return;
8274
}
8375

84-
items.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
85-
store.updateCache(id, items);
76+
elements.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
77+
store.updateCache(id, elements);
8678

8779
} finally {
8880
lock.unlock();
@@ -101,30 +93,49 @@ public Optional<AiMessage> search(SystemMessage systemMessage, UserMessage userM
10193
else
10294
query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), userMessage.text());
10395

104-
var elements = store.getAll(id);
105-
double maxScore = 0;
106-
AiMessage result = null;
96+
try {
10797

108-
for (var cacheRecord : elements) {
98+
lock.lock();
10999

110-
if (ttl != null) {
111-
var expiredTime = Date.from(cacheRecord.creation().plus(ttl));
112-
var currentTime = new Date();
100+
double maxScore = 0;
101+
AiMessage result = null;
102+
List<CacheRecord> records = store.getAll(id)
103+
.stream()
104+
.filter(this::checkTTL)
105+
.collect(Collectors.toList());
113106

114-
if (currentTime.after(expiredTime))
115-
continue;
116-
}
107+
for (var record : records) {
117108

118-
var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), cacheRecord.embedded());
119-
var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
109+
var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), record.embedded());
110+
var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
120111

121-
if (score >= threshold.doubleValue() && score >= maxScore) {
122-
maxScore = score;
123-
result = cacheRecord.response();
112+
if (score >= threshold.doubleValue() && score >= maxScore) {
113+
maxScore = score;
114+
result = record.response();
115+
}
124116
}
117+
118+
store.updateCache(id, records);
119+
return Optional.ofNullable(result);
120+
121+
} finally {
122+
lock.unlock();
123+
}
124+
}
125+
126+
private boolean checkTTL(CacheRecord record) {
127+
128+
if (ttl == null)
129+
return true;
130+
131+
var expiredTime = Date.from(record.creation().plus(ttl));
132+
var currentTime = new Date();
133+
134+
if (currentTime.after(expiredTime)) {
135+
return false;
125136
}
126137

127-
return Optional.ofNullable(result);
138+
return true;
128139
}
129140

130141
@Override
@@ -187,7 +198,7 @@ public Builder embeddingModel(EmbeddingModel embeddingModel) {
187198
}
188199

189200
public AiCache build() {
190-
return new MessageWindowAiCache(this);
201+
return new FixedAiCache(this);
191202
}
192203
}
193204
}

model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,22 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
7676
@Test
7777
void cache_test() {
7878

79-
String cacheId = "default";
79+
String chatCacheId = "#" + LLMService.class.getName() + ".chat";
80+
String chat2CacheId = "#" + LLMService.class.getName() + ".chat2";
81+
82+
assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
83+
assertEquals(0, aiCacheStore.getAll(chat2CacheId).size());
8084

81-
assertEquals(0, aiCacheStore.getAll(cacheId).size());
8285
service.chat("chat");
83-
assertEquals(1, aiCacheStore.getAll(cacheId).size());
86+
assertEquals(1, aiCacheStore.getAll(chatCacheId).size());
87+
assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text());
88+
assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector());
89+
assertEquals(0, aiCacheStore.getAll(chat2CacheId).size());
8490

8591
service.chat2("chat2");
86-
assertEquals(1, aiCacheStore.getAll(cacheId).size());
87-
assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text());
88-
assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
92+
assertEquals(1, aiCacheStore.getAll(chat2CacheId).size());
93+
assertEquals("result", aiCacheStore.getAll(chat2CacheId).get(0).response().text());
94+
assertEquals(es, aiCacheStore.getAll(chat2CacheId).get(0).embedded().vector());
8995
}
9096

9197
@Test

model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
import io.quarkiverse.langchain4j.CacheResult;
2727
import io.quarkiverse.langchain4j.RegisterAiService;
2828
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
29+
import io.quarkus.arc.Arc;
30+
import io.quarkus.arc.ArcContainer;
31+
import io.quarkus.arc.ManagedContext;
2932
import io.quarkus.test.QuarkusUnitTest;
3033

3134
public class CacheConfigTest {
@@ -85,7 +88,7 @@ else if (textSegments.get(0).text().equals("TESTFOURTH"))
8588
@Order(1)
8689
void cache_ttl_test() throws InterruptedException {
8790

88-
String cacheId = "default";
91+
String cacheId = "#" + LLMService.class.getName() + ".chat";
8992
aiCacheStore.deleteCache(cacheId);
9093

9194
service.chat("FIRST");
@@ -107,7 +110,7 @@ void cache_ttl_test() throws InterruptedException {
107110
@Order(2)
108111
void cache_max_size_test() {
109112

110-
String cacheId = "default";
113+
String cacheId = "#" + LLMService.class.getName() + ".chat";
111114
aiCacheStore.deleteCache(cacheId);
112115

113116
service.chat("FIRST");
@@ -119,12 +122,18 @@ void cache_max_size_test() {
119122
service.chat("THIRD");
120123
service.chat("FOURTH");
121124
assertEquals(3, aiCacheStore.getAll(cacheId).size());
122-
assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(0).response().text());
123-
assertEquals(second, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
124-
assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(1).response().text());
125-
assertEquals(third, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
126-
assertEquals("cache: TESTFOURTH", aiCacheStore.getAll(cacheId).get(2).response().text());
127-
assertEquals(fourth, aiCacheStore.getAll(cacheId).get(2).embedded().vector());
125+
assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text());
126+
assertEquals(first, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
127+
assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(1).response().text());
128+
assertEquals(second, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
129+
assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(2).response().text());
130+
assertEquals(third, aiCacheStore.getAll(cacheId).get(2).embedded().vector());
131+
}
132+
133+
private String getContext(String methodName) {
134+
ArcContainer container = Arc.container();
135+
ManagedContext requestContext = container.requestContext();
136+
return requestContext.getState() + "#" + LLMService.class.getName() + "." + methodName;
128137
}
129138

130139
static float[] first = {

model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,20 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
7979
@Test
8080
void cache_test() {
8181

82-
String cacheId = "default";
82+
String chatCacheId = "#" + LLMService.class.getName() + ".chat";
83+
String chatNoCacheCacheId = "#" + LLMService.class.getName() + ".chatNoCache";
8384

84-
assertEquals(0, aiCacheStore.getAll(cacheId).size());
85+
assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
86+
assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
8587
service.chatNoCache("noCache");
86-
assertEquals(0, aiCacheStore.getAll(cacheId).size());
88+
assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
89+
assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
8790

8891
service.chat("cache");
89-
assertEquals(1, aiCacheStore.getAll(cacheId).size());
90-
assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text());
91-
assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
92+
assertEquals(1, aiCacheStore.getAll(chatCacheId).size());
93+
assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text());
94+
assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector());
95+
assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
9296
}
9397

9498
@Test

model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
9494
@Order(1)
9595
void cache_prefix_test() throws InterruptedException {
9696

97-
String cacheId = "default";
97+
String cacheId = "#" + LLMService.class.getName() + ".chat";
9898
aiCacheStore.deleteCache(cacheId);
9999

100100
service.chat("firstMessage");

0 commit comments

Comments
 (0)