Skip to content

Commit f59a07c

Browse files
committed
Make response handling more generic. Resulting in generic cost handling.
1 parent 2e8615a commit f59a07c

File tree

18 files changed

+327
-28
lines changed

18 files changed

+327
-28
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
import java.util.Optional;
44

5+
import jakarta.inject.Singleton;
6+
7+
import org.jboss.jandex.DotName;
8+
9+
import io.quarkiverse.langchain4j.cost.CostEstimatorResponseListener;
10+
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
511
import io.quarkiverse.langchain4j.runtime.listeners.MetricsChatModelListener;
612
import io.quarkiverse.langchain4j.runtime.listeners.SpanChatModelListener;
713
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
@@ -14,6 +20,20 @@
1420

1521
public class ListenersProcessor {
1622

23+
@BuildStep
24+
public void costListener(
25+
LangChain4jBuildConfig config,
26+
BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
27+
if (config.costListener()) {
28+
additionalBeanProducer.produce(
29+
AdditionalBeanBuildItem.builder()
30+
.addBeanClass(CostEstimatorResponseListener.class)
31+
.setDefaultScope(DotName.createSimple(Singleton.class))
32+
.setUnremovable()
33+
.build());
34+
}
35+
}
36+
1737
@BuildStep
1838
public void spanListeners(Capabilities capabilities,
1939
Optional<MetricsCapabilityBuildItem> metricsCapability,

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ public interface LangChain4jBuildConfig {
4343
@WithDefault("true")
4444
boolean responseSchema();
4545

46+
/**
47+
* Configuration property to enable or disable generic cost listener
48+
*/
49+
@WithDefault("false")
50+
boolean costListener();
51+
4652
interface BaseConfig {
4753
/**
4854
* Chat model
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package io.quarkiverse.langchain4j.cost;
2+
3+
import java.util.ArrayList;
4+
import java.util.Comparator;
5+
import java.util.List;
6+
7+
import jakarta.inject.Inject;
8+
9+
import dev.langchain4j.model.output.TokenUsage;
10+
import io.quarkiverse.langchain4j.response.ResponseListener;
11+
import io.quarkiverse.langchain4j.response.ResponseRecord;
12+
import io.quarkus.arc.All;
13+
import io.smallrye.common.annotation.Experimental;
14+
15+
/**
16+
* Allows for user code to provide a custom strategy for estimating the cost of API calls
17+
*/
18+
@Experimental("This feature is experimental and the API is subject to change")
19+
public class CostEstimatorResponseListener implements ResponseListener {
20+
21+
private final CostEstimatorService service;
22+
private final List<CostListener> listeners;
23+
24+
@Inject
25+
public CostEstimatorResponseListener(CostEstimatorService service, @All List<CostListener> listeners) {
26+
this.service = service;
27+
this.listeners = new ArrayList<>(listeners);
28+
this.listeners.sort(Comparator.comparingInt(CostListener::order));
29+
}
30+
31+
@Override
32+
public void onResponse(ResponseRecord rr) {
33+
String model = rr.model();
34+
TokenUsage tokenUsage = rr.tokenUsage();
35+
CostEstimator.CostContext context = new MyCostContext(tokenUsage, model);
36+
Cost cost = service.estimate(context);
37+
if (cost != null) {
38+
for (CostListener cl : listeners) {
39+
cl.handleCost(model, tokenUsage, cost);
40+
}
41+
}
42+
}
43+
44+
private record MyCostContext(TokenUsage tokenUsage, String model) implements CostEstimator.CostContext {
45+
@Override
46+
public Integer inputTokens() {
47+
return tokenUsage().inputTokenCount();
48+
}
49+
50+
@Override
51+
public Integer outputTokens() {
52+
return tokenUsage().outputTokenCount();
53+
}
54+
}
55+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ public CostEstimatorService(@All List<CostEstimator> costEstimators) {
2828
public Cost estimate(ChatModelResponseContext response) {
2929
TokenUsage tokenUsage = response.response().tokenUsage();
3030
CostEstimator.CostContext costContext = new MyCostContext(tokenUsage, response);
31+
return estimate(costContext);
32+
}
3133

34+
public Cost estimate(CostEstimator.CostContext context) {
3235
for (CostEstimator costEstimator : costEstimators) {
33-
if (costEstimator.supports(costContext)) {
34-
CostEstimator.CostResult costResult = costEstimator.estimate(costContext);
36+
if (costEstimator.supports(context)) {
37+
CostEstimator.CostResult costResult = costEstimator.estimate(context);
3538
if (costResult != null) {
3639
BigDecimal totalCost = costResult.inputTokensCost().add(costResult.outputTokensCost());
3740
return new Cost(totalCost, costResult.currency());
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package io.quarkiverse.langchain4j.cost;
2+
3+
import dev.langchain4j.model.output.TokenUsage;
4+
5+
/**
6+
* Allows for user code to handle estimate cost; e.g. some simple accounting
7+
*/
8+
public interface CostListener {
9+
void handleCost(String model, TokenUsage tokenUsage, Cost cost);
10+
11+
default int order() {
12+
return 0;
13+
}
14+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package io.quarkiverse.langchain4j.response;
2+
3+
import java.util.Map;
4+
5+
import jakarta.annotation.Priority;
6+
import jakarta.interceptor.AroundInvoke;
7+
import jakarta.interceptor.Interceptor;
8+
import jakarta.interceptor.InvocationContext;
9+
10+
import dev.langchain4j.data.message.AiMessage;
11+
import dev.langchain4j.model.chat.listener.ChatModelResponse;
12+
import dev.langchain4j.model.chat.response.ChatResponse;
13+
import dev.langchain4j.model.output.Response;
14+
15+
/**
16+
* Simple (Chat)Response interceptor, to be applied directly on the model.
17+
*/
18+
@Interceptor
19+
@ResponseInterceptorBinding
20+
@Priority(0)
21+
public class ResponseInterceptor extends ResponseInterceptorBase {
22+
23+
@AroundInvoke
24+
public Object intercept(InvocationContext context) throws Exception {
25+
Object result = context.proceed();
26+
ResponseRecord rr = null;
27+
if (result instanceof Response<?> response) {
28+
Object content = response.content();
29+
if (content instanceof AiMessage am) {
30+
rr = new ResponseRecord(getModel(context.getTarget()), am, response.tokenUsage(), response.finishReason(),
31+
response.metadata());
32+
}
33+
} else if (result instanceof ChatResponse response) {
34+
rr = new ResponseRecord(getModel(context.getTarget()), response.aiMessage(), response.tokenUsage(),
35+
response.finishReason(), Map.of());
36+
} else if (result instanceof ChatModelResponse response) {
37+
rr = new ResponseRecord(response.model(), response.aiMessage(), response.tokenUsage(), response.finishReason(),
38+
Map.of("id", response.id()));
39+
}
40+
if (rr != null) {
41+
for (ResponseListener l : getListeners()) {
42+
l.onResponse(rr);
43+
}
44+
}
45+
return result;
46+
}
47+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package io.quarkiverse.langchain4j.response;
2+
3+
import java.lang.reflect.Method;
4+
import java.util.Comparator;
5+
import java.util.List;
6+
7+
import jakarta.enterprise.inject.Any;
8+
import jakarta.enterprise.inject.spi.CDI;
9+
10+
/**
11+
* Simple (Chat)Response interceptor base, to be applied directly on the model.
12+
*/
13+
public abstract class ResponseInterceptorBase {
14+
15+
private volatile String model;
16+
private volatile List<ResponseListener> listeners;
17+
18+
// TODO -- uh uh ... reflection ... puke
19+
protected String getModel(Object target) {
20+
if (model == null) {
21+
try {
22+
Class<?> clazz = target.getClass();
23+
Method method = clazz.getMethod("modelName");
24+
model = (String) method.invoke(target);
25+
} catch (Exception e) {
26+
throw new RuntimeException(e);
27+
}
28+
}
29+
return model;
30+
}
31+
32+
protected List<ResponseListener> getListeners() {
33+
if (listeners == null) {
34+
listeners = CDI.current().select(ResponseListener.class, Any.Literal.INSTANCE)
35+
.stream()
36+
.sorted(Comparator.comparing(ResponseListener::order))
37+
.toList();
38+
}
39+
return listeners;
40+
}
41+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package io.quarkiverse.langchain4j.response;
2+
3+
import java.lang.annotation.ElementType;
4+
import java.lang.annotation.Retention;
5+
import java.lang.annotation.RetentionPolicy;
6+
import java.lang.annotation.Target;
7+
8+
import jakarta.interceptor.InterceptorBinding;
9+
10+
@InterceptorBinding
11+
@Target({ ElementType.TYPE, ElementType.METHOD })
12+
@Retention(RetentionPolicy.RUNTIME)
13+
public @interface ResponseInterceptorBinding {
14+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package io.quarkiverse.langchain4j.response;
2+
3+
@ResponseInterceptorBinding
4+
public abstract class ResponseInterceptorBindingSource {
5+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package io.quarkiverse.langchain4j.response;
2+
3+
/**
4+
* Simple ResponseRecord listener, to be implemented by the (advanced) users.
5+
*/
6+
public interface ResponseListener {
7+
void onResponse(ResponseRecord response);
8+
9+
default int order() {
10+
return 0;
11+
}
12+
}

0 commit comments

Comments
 (0)