51
51
import org .objectweb .asm .tree .analysis .AnalyzerException ;
52
52
53
53
import dev .langchain4j .exception .IllegalConfigurationException ;
54
- import dev .langchain4j .service .Moderate ;
55
54
import io .quarkiverse .langchain4j .ModelName ;
56
55
import io .quarkiverse .langchain4j .ToolBox ;
57
56
import io .quarkiverse .langchain4j .deployment .items .SelectedChatModelProviderBuildItem ;
@@ -185,6 +184,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
185
184
186
185
Set <String > chatModelNames = new HashSet <>();
187
186
Set <String > moderationModelNames = new HashSet <>();
187
+
188
188
for (AnnotationInstance instance : index .getAnnotations (LangChain4jDotNames .REGISTER_AI_SERVICES )) {
189
189
if (instance .target ().kind () != AnnotationTarget .Kind .CLASS ) {
190
190
continue ; // should never happen
@@ -206,14 +206,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
206
206
}
207
207
208
208
String chatModelName = NamedConfigUtil .DEFAULT_NAME ;
209
+ String moderationModelName = NamedConfigUtil .DEFAULT_NAME ;
210
+ String embeddingModelName = getModelName (instance .value ("modelName" ));
211
+
209
212
if (chatLanguageModelSupplierClassDotName == null ) {
210
213
AnnotationValue modelNameValue = instance .value ("modelName" );
211
- if (modelNameValue != null ) {
212
- String modelNameValueStr = modelNameValue .asString ();
213
- if ((modelNameValueStr != null ) && !modelNameValueStr .isEmpty ()) {
214
- chatModelName = modelNameValueStr ;
215
- }
216
- }
214
+ chatModelName = getModelName (modelNameValue );
217
215
chatModelNames .add (chatModelName );
218
216
}
219
217
@@ -239,6 +237,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
239
237
}
240
238
}
241
239
240
+ // the default value depends on whether tools exists or not - if they do, then we require a AiCacheProvider bean
241
+ DotName aiCacheProviderSupplierClassDotName = LangChain4jDotNames .BEAN_AI_CACHE_PROVIDER_SUPPLIER ;
242
+ AnnotationValue aiCacheProviderSupplierValue = instance .value ("cacheProviderSupplier" );
243
+ if (aiCacheProviderSupplierValue != null ) {
244
+ aiCacheProviderSupplierClassDotName = aiCacheProviderSupplierValue .asClass ().name ();
245
+ if (!aiCacheProviderSupplierClassDotName
246
+ .equals (LangChain4jDotNames .BEAN_AI_CACHE_PROVIDER_SUPPLIER )) {
247
+ validateSupplierAndRegisterForReflection (aiCacheProviderSupplierClassDotName , index ,
248
+ reflectiveClassProducer );
249
+ }
250
+ }
251
+
242
252
DotName retrieverClassDotName = null ;
243
253
AnnotationValue retrieverValue = instance .value ("retriever" );
244
254
if (retrieverValue != null ) {
@@ -292,17 +302,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
292
302
}
293
303
294
304
// determine whether the method is annotated with @Moderate
295
- String moderationModelName = NamedConfigUtil .DEFAULT_NAME ;
296
305
for (MethodInfo method : declarativeAiServiceClassInfo .methods ()) {
297
306
if (method .hasAnnotation (LangChain4jDotNames .MODERATE )) {
298
307
if (moderationModelSupplierClassName .equals (LangChain4jDotNames .BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER )) {
299
308
AnnotationValue modelNameValue = instance .value ("modelName" );
300
- if (modelNameValue != null ) {
301
- String modelNameValueStr = modelNameValue .asString ();
302
- if ((modelNameValueStr != null ) && !modelNameValueStr .isEmpty ()) {
303
- moderationModelName = modelNameValueStr ;
304
- }
305
- }
309
+ moderationModelName = getModelName (modelNameValue );
306
310
moderationModelNames .add (moderationModelName );
307
311
}
308
312
break ;
@@ -321,13 +325,16 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
321
325
chatLanguageModelSupplierClassDotName ,
322
326
toolDotNames ,
323
327
chatMemoryProviderSupplierClassDotName ,
328
+ aiCacheProviderSupplierClassDotName ,
324
329
retrieverClassDotName ,
325
330
retrievalAugmentorSupplierClassName ,
326
331
customRetrievalAugmentorSupplierClassIsABean ,
327
332
auditServiceSupplierClassName ,
328
333
moderationModelSupplierClassName ,
329
334
cdiScope ,
330
- chatModelName , moderationModelName ));
335
+ chatModelName ,
336
+ moderationModelName ,
337
+ embeddingModelName ));
331
338
}
332
339
333
340
for (String chatModelName : chatModelNames ) {
@@ -361,7 +368,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
361
368
List <DeclarativeAiServiceBuildItem > declarativeAiServiceItems ,
362
369
List <SelectedChatModelProviderBuildItem > selectedChatModelProvider ,
363
370
BuildProducer <SyntheticBeanBuildItem > syntheticBeanProducer ,
364
- BuildProducer <UnremovableBeanBuildItem > unremoveableProducer ) {
371
+ BuildProducer <UnremovableBeanBuildItem > unremoveableProducer ,
372
+ AiCacheBuildItem aiCacheBuildItem ) {
365
373
366
374
boolean needsChatModelBean = false ;
367
375
boolean needsStreamingChatModelBean = false ;
@@ -370,6 +378,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
370
378
boolean needsRetrievalAugmentorBean = false ;
371
379
boolean needsAuditServiceBean = false ;
372
380
boolean needsModerationModelBean = false ;
381
+ boolean needsAiCacheProvider = false ;
382
+
373
383
Set <DotName > allToolNames = new HashSet <>();
374
384
375
385
for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems ) {
@@ -386,6 +396,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
386
396
? bi .getChatMemoryProviderSupplierClassDotName ().toString ()
387
397
: null ;
388
398
399
+ String aiCacheProviderSupplierClassName = bi .getAiCacheProviderSupplierClassDotName () != null
400
+ ? bi .getAiCacheProviderSupplierClassDotName ().toString ()
401
+ : null ;
402
+
389
403
String retrieverClassName = bi .getRetrieverClassDotName () != null
390
404
? bi .getRetrieverClassDotName ().toString ()
391
405
: null ;
@@ -403,7 +417,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
403
417
: null );
404
418
405
419
// determine whether the method returns Multi<String>
406
- boolean injectStreamingChatModelBean = false ;
420
+ boolean needsStreamingChatModel = false ;
407
421
for (MethodInfo method : declarativeAiServiceClassInfo .methods ()) {
408
422
if (!LangChain4jDotNames .MULTI .equals (method .returnType ().name ())) {
409
423
continue ;
@@ -419,29 +433,36 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
419
433
throw illegalConfiguration ("Only Multi<String> is supported as a Multi return type. Offending method is '"
420
434
+ method .declaringClass ().name ().toString () + "#" + method .name () + "'" );
421
435
}
422
- injectStreamingChatModelBean = true ;
436
+ needsStreamingChatModel = true ;
423
437
}
424
438
425
- boolean injectModerationModelBean = false ;
439
+ boolean needsModerationModel = false ;
426
440
for (MethodInfo method : declarativeAiServiceClassInfo .methods ()) {
427
- if (method .hasAnnotation (Moderate . class )) {
428
- injectModerationModelBean = true ;
441
+ if (method .hasAnnotation (LangChain4jDotNames . MODERATE )) {
442
+ needsModerationModel = true ;
429
443
break ;
430
444
}
431
445
}
432
446
433
447
String chatModelName = bi .getChatModelName ();
434
448
String moderationModelName = bi .getModerationModelName ();
449
+ String embeddingModelName = bi .getEmbeddingModelName ();
450
+ boolean enableCache = aiCacheBuildItem .isEnable ();
451
+
435
452
SyntheticBeanBuildItem .ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
436
453
.configure (QuarkusAiServiceContext .class )
437
454
.forceApplicationClass ()
438
455
.createWith (recorder .createDeclarativeAiService (
439
456
new DeclarativeAiServiceCreateInfo (serviceClassName , chatLanguageModelSupplierClassName ,
440
- toolClassNames , chatMemoryProviderSupplierClassName , retrieverClassName ,
457
+ toolClassNames , chatMemoryProviderSupplierClassName , aiCacheProviderSupplierClassName ,
458
+ retrieverClassName ,
441
459
retrievalAugmentorSupplierClassName ,
442
460
auditServiceClassSupplierName , moderationModelSupplierClassName , chatModelName ,
443
461
moderationModelName ,
444
- injectStreamingChatModelBean , injectModerationModelBean )))
462
+ embeddingModelName ,
463
+ needsStreamingChatModel ,
464
+ needsModerationModel ,
465
+ enableCache )))
445
466
.setRuntimeInit ()
446
467
.addQualifier ()
447
468
.annotation (LangChain4jDotNames .QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER ).addValue ("value" , serviceClassName )
@@ -451,15 +472,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
451
472
if ((chatLanguageModelSupplierClassName == null ) && !selectedChatModelProvider .isEmpty ()) {
452
473
if (NamedConfigUtil .isDefault (chatModelName )) {
453
474
configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .CHAT_MODEL ));
454
- if (injectStreamingChatModelBean ) {
475
+ if (needsStreamingChatModel ) {
455
476
configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .STREAMING_CHAT_MODEL ));
456
477
needsStreamingChatModelBean = true ;
457
478
}
458
479
} else {
459
480
configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .CHAT_MODEL ),
460
481
AnnotationInstance .builder (ModelName .class ).add ("value" , chatModelName ).build ());
461
482
462
- if (injectStreamingChatModelBean ) {
483
+ if (needsStreamingChatModel ) {
463
484
configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .STREAMING_CHAT_MODEL ),
464
485
AnnotationInstance .builder (ModelName .class ).add ("value" , chatModelName ).build ());
465
486
needsStreamingChatModelBean = true ;
@@ -515,7 +536,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
515
536
}
516
537
517
538
if (LangChain4jDotNames .BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER .toString ()
518
- .equals (moderationModelSupplierClassName ) && injectModerationModelBean ) {
539
+ .equals (moderationModelSupplierClassName ) && needsModerationModel ) {
519
540
520
541
if (NamedConfigUtil .isDefault (moderationModelName )) {
521
542
configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .MODERATION_MODEL ));
@@ -527,6 +548,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
527
548
needsModerationModelBean = true ;
528
549
}
529
550
551
+ if (enableCache ) {
552
+ if (LangChain4jDotNames .BEAN_AI_CACHE_PROVIDER_SUPPLIER .toString ().equals (aiCacheProviderSupplierClassName )) {
553
+ configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .AI_CACHE_PROVIDER ));
554
+ }
555
+ configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .AI_CACHE_PROVIDER ));
556
+ configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .EMBEDDING_MODEL ));
557
+ needsAiCacheProvider = true ;
558
+ }
559
+
530
560
syntheticBeanProducer .produce (configurator .done ());
531
561
}
532
562
@@ -551,6 +581,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
551
581
if (needsModerationModelBean ) {
552
582
unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (LangChain4jDotNames .MODERATION_MODEL ));
553
583
}
584
+ if (needsAiCacheProvider ) {
585
+ unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (LangChain4jDotNames .AI_CACHE_PROVIDER ));
586
+ unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (LangChain4jDotNames .EMBEDDING_MODEL ));
587
+ }
554
588
if (!allToolNames .isEmpty ()) {
555
589
unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (allToolNames ));
556
590
}
@@ -870,6 +904,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
870
904
}
871
905
872
906
boolean requiresModeration = method .hasAnnotation (LangChain4jDotNames .MODERATE );
907
+ boolean requiresCache = method .declaringClass ().hasDeclaredAnnotation (LangChain4jDotNames .CACHE_RESULT )
908
+ || method .hasDeclaredAnnotation (LangChain4jDotNames .CACHE_RESULT );
873
909
874
910
List <MethodParameterInfo > params = method .parameters ();
875
911
@@ -887,7 +923,7 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
887
923
List <String > methodToolClassNames = gatherMethodToolClassNames (method );
888
924
889
925
return new AiServiceMethodCreateInfo (method .declaringClass ().name ().toString (), method .name (), systemMessageInfo ,
890
- userMessageInfo , memoryIdParamPosition , requiresModeration ,
926
+ userMessageInfo , memoryIdParamPosition , requiresModeration , requiresCache ,
891
927
returnType , metricsTimedInfo , metricsCountedInfo , spanInfo , methodToolClassNames );
892
928
}
893
929
@@ -1222,6 +1258,16 @@ static Map<String, Integer> toNameToArgsPositionMap(List<TemplateParameterInfo>
1222
1258
}
1223
1259
}
1224
1260
1261
+ private String getModelName (AnnotationValue value ) {
1262
+ if (value != null ) {
1263
+ String modelNameValueStr = value .asString ();
1264
+ if ((modelNameValueStr != null ) && !modelNameValueStr .isEmpty ()) {
1265
+ return modelNameValueStr ;
1266
+ }
1267
+ }
1268
+ return NamedConfigUtil .DEFAULT_NAME ;
1269
+ }
1270
+
1225
1271
public static final class AiServicesMethodBuildItem extends MultiBuildItem {
1226
1272
1227
1273
private final MethodInfo methodInfo ;
0 commit comments