Skip to content

Commit 77f6a00

Browse files
committed
Add fine-grained control for MCP server(s) provided tools
1 parent adaaea8 commit 77f6a00

File tree

17 files changed

+733
-248
lines changed

17 files changed

+733
-248
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
13751375
toolQualifierProviders.stream().map(
13761376
ToolQualifierProvider.BuildItem::getProvider).toList());
13771377

1378+
List<String> methodMcpClientNames = gatherMethodMcpClientNames(method);
1379+
13781380
List<String> outputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, OUTPUT_GUARDRAILS);
13791381
List<String> inputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, INPUT_GUARDRAILS);
13801382

@@ -1390,8 +1392,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
13901392
userMessageInfo, memoryIdParamPosition, requiresModeration,
13911393
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
13921394
overrideChatModelParamPosition, metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo,
1393-
methodToolClassInfo, switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails,
1394-
accumulatorClassName, responseAugmenterClassName);
1395+
methodToolClassInfo, methodMcpClientNames, switchToWorkerThreadForToolExecution, inputGuardrails,
1396+
outputGuardrails, accumulatorClassName, responseAugmenterClassName);
13951397
}
13961398

13971399
private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
@@ -1839,6 +1841,26 @@ private List<String> gatherMethodToolClassNames(MethodInfo method) {
18391841
return Arrays.stream(toolClasses).map(t -> t.name().toString()).collect(Collectors.toList());
18401842
}
18411843

1844+
private List<String> gatherMethodMcpClientNames(MethodInfo method) {
1845+
// Using the class name to keep the McpToolBox annotation in the mcp module
1846+
AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox");
1847+
if (mcpToolBoxInstance == null) {
1848+
return null;
1849+
}
1850+
1851+
AnnotationValue mcpToolBoxValue = mcpToolBoxInstance.value();
1852+
if (mcpToolBoxValue == null) {
1853+
return Collections.emptyList();
1854+
}
1855+
1856+
String[] mcpClientNames = mcpToolBoxValue.asStringArray();
1857+
if (mcpClientNames.length == 0) {
1858+
return Collections.emptyList();
1859+
}
1860+
1861+
return Arrays.asList(mcpClientNames);
1862+
}
1863+
18421864
private DotName determineChatMemorySeeder(ClassInfo iface, ClassOutput classOutput) {
18431865
List<AnnotationInstance> annotations = iface.annotations(SEED_MEMORY);
18441866
if (annotations.isEmpty()) {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public final class AiServiceMethodCreateInfo {
4141
private final Optional<SpanInfo> spanInfo;
4242
// support @Toolbox
4343
private final Map<String, AnnotationLiteral<?>> toolClassInfo;
44+
private final List<String> mcpClientNames;
4445
private final ResponseSchemaInfo responseSchemaInfo;
4546

4647
// support for guardrails
@@ -78,6 +79,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName,
7879
Optional<SpanInfo> spanInfo,
7980
ResponseSchemaInfo responseSchemaInfo,
8081
Map<String, AnnotationLiteral<?>> toolClassInfo,
82+
List<String> mcpClientNames,
8183
boolean switchToWorkerThreadForToolExecution,
8284
List<String> inputGuardrailsClassNames,
8385
List<String> outputGuardrailsClassNames,
@@ -102,6 +104,7 @@ public Type get() {
102104
this.spanInfo = spanInfo;
103105
this.responseSchemaInfo = responseSchemaInfo;
104106
this.toolClassInfo = toolClassInfo;
107+
this.mcpClientNames = mcpClientNames;
105108
this.inputGuardrailsClassNames = inputGuardrailsClassNames;
106109
this.outputGuardrailsClassNames = outputGuardrailsClassNames;
107110
this.outputTokenAccumulatorClassName = outputTokenAccumulatorClassName;
@@ -173,6 +176,10 @@ public Map<String, AnnotationLiteral<?>> getToolClassInfo() {
173176
return toolClassInfo;
174177
}
175178

179+
public List<String> getMcpClientNames() {
180+
return mcpClientNames;
181+
}
182+
176183
public List<ToolSpecification> getToolSpecifications() {
177184
return toolSpecifications;
178185
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
192192
if (context.toolService.toolProvider() != null) {
193193
toolSpecifications = toolSpecifications != null ? new ArrayList<>(toolSpecifications) : new ArrayList<>();
194194
toolExecutors = toolExecutors != null ? new HashMap<>(toolExecutors) : new HashMap<>();
195-
ToolProviderRequest request = new ToolProviderRequest(memoryId, userMessage);
195+
ToolProviderRequest request = new QuarkusToolProviderRequest(memoryId, userMessage,
196+
methodCreateInfo.getMcpClientNames());
196197
ToolProviderResult result = context.toolService.toolProvider().provideTools(request);
197198
for (ToolSpecification specification : result.tools().keySet()) {
198199
toolSpecifications.add(specification);
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package io.quarkiverse.langchain4j.runtime.aiservice;
2+
3+
import java.util.List;
4+
5+
import dev.langchain4j.data.message.UserMessage;
6+
import dev.langchain4j.service.tool.ToolProviderRequest;
7+
8+
public class QuarkusToolProviderRequest extends ToolProviderRequest {
9+
10+
private final List<String> mcpClientNames;
11+
12+
public QuarkusToolProviderRequest(Object chatMemoryId, UserMessage userMessage, List<String> mcpClientNames) {
13+
super(chatMemoryId, userMessage);
14+
this.mcpClientNames = mcpClientNames;
15+
}
16+
17+
public List<String> getMcpClientNames() {
18+
return mcpClientNames;
19+
}
20+
}
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
package io.quarkiverse.langchain4j.mcp.test;
2+
3+
import java.util.Map;
4+
import java.util.Random;
5+
import java.util.concurrent.CompletableFuture;
6+
import java.util.concurrent.ConcurrentHashMap;
7+
import java.util.concurrent.ScheduledExecutorService;
8+
import java.util.concurrent.TimeUnit;
9+
import java.util.concurrent.atomic.AtomicLong;
10+
11+
import jakarta.inject.Inject;
12+
import jakarta.ws.rs.Consumes;
13+
import jakarta.ws.rs.POST;
14+
import jakarta.ws.rs.Path;
15+
import jakarta.ws.rs.Produces;
16+
import jakarta.ws.rs.core.Context;
17+
import jakarta.ws.rs.core.MediaType;
18+
import jakarta.ws.rs.core.Response;
19+
import jakarta.ws.rs.sse.Sse;
20+
import jakarta.ws.rs.sse.SseEventSink;
21+
22+
import org.jboss.logging.Logger;
23+
import org.jboss.resteasy.reactive.RestStreamElementType;
24+
25+
import com.fasterxml.jackson.databind.JsonNode;
26+
import com.fasterxml.jackson.databind.ObjectMapper;
27+
import com.fasterxml.jackson.databind.node.ObjectNode;
28+
29+
public abstract class AbstractMockHttpMcpServer {
30+
31+
private final AtomicLong ID_GENERATOR = new AtomicLong(new Random().nextLong(1000, 5000));
32+
33+
private static Logger logger = Logger.getLogger(MockHttpMcpServer.class);
34+
35+
private volatile boolean shouldRespondToPing = true;
36+
37+
// key = operation ID of the ping
38+
// value = future that will be completed when the ping response for that ID is received
39+
final Map<Long, CompletableFuture<Void>> pendingPings = new ConcurrentHashMap<>();
40+
41+
private volatile SseEventSink sink;
42+
private volatile Sse sse;
43+
private final ObjectMapper objectMapper = new ObjectMapper();
44+
private volatile boolean initializationNotificationReceived = false;
45+
46+
@Inject
47+
ScheduledExecutorService scheduledExecutorService;
48+
49+
@Path("/sse")
50+
@Produces(MediaType.SERVER_SENT_EVENTS)
51+
@RestStreamElementType(MediaType.TEXT_PLAIN)
52+
public void sse(@Context SseEventSink sink, @Context Sse sse) {
53+
this.sink = sink;
54+
this.sse = sse;
55+
sink.send(sse.newEventBuilder()
56+
.id("id")
57+
.name("endpoint")
58+
.mediaType(MediaType.TEXT_PLAIN_TYPE)
59+
.data("/" + getEndpoint() + "/post")
60+
.build());
61+
}
62+
63+
protected abstract String getEndpoint();
64+
65+
@Path("/post")
66+
@Consumes(MediaType.APPLICATION_JSON)
67+
@POST
68+
public Response post(JsonNode message) {
69+
if (message.get("method") != null) {
70+
String method = message.get("method").asText();
71+
if (method.equals("notifications/cancelled")) {
72+
return Response.ok().build();
73+
}
74+
if (method.equals("notifications/initialized")) {
75+
if (initializationNotificationReceived) {
76+
return Response.serverError().entity("Duplicate 'notifications/initialized' message").build();
77+
}
78+
initializationNotificationReceived = true;
79+
return Response.ok().build();
80+
}
81+
String operationId = message.get("id").asText();
82+
if (method.equals("initialize")) {
83+
initialize(operationId);
84+
} else if (method.equals("tools/list")) {
85+
ensureInitialized();
86+
listTools(operationId);
87+
} else if (method.equals("tools/call")) {
88+
ensureInitialized();
89+
if (message.get("params").get("name").asText().equals("add")) {
90+
executeAddOperation(message, operationId);
91+
} else if (message.get("params").get("name").asText().equals("logging")) {
92+
executeLoggingOperation(message, operationId);
93+
} else if (message.get("params").get("name").asText().equals("longRunningOperation")) {
94+
executeLongRunningOperation(message, operationId);
95+
} else {
96+
return Response.serverError().entity("Unknown operation").build();
97+
}
98+
99+
} else if (method.equals("ping")) {
100+
if (shouldRespondToPing) {
101+
ObjectNode result = buildPongMessage(operationId);
102+
sink.send(sse.newEventBuilder()
103+
.name("message")
104+
.data(result)
105+
.build());
106+
} else {
107+
logger.info("Ignoring ping request");
108+
}
109+
return Response.accepted().build();
110+
}
111+
} else {
112+
// if 'method' is null, the message is probably a ping response
113+
long id = message.get("id").asLong();
114+
CompletableFuture<Void> future = pendingPings.remove(id);
115+
if (future != null) {
116+
future.complete(null);
117+
} else {
118+
return Response.serverError().entity("Received a ping response with unknown ID " + id).build();
119+
}
120+
}
121+
return Response.accepted().build();
122+
}
123+
124+
private ObjectNode buildPongMessage(String operationId) {
125+
ObjectNode pong = objectMapper.createObjectNode();
126+
pong.put("jsonrpc", "2.0");
127+
pong.put("id", operationId);
128+
pong.put("result", objectMapper.createObjectNode());
129+
return pong;
130+
}
131+
132+
private void executeLoggingOperation(JsonNode message, String operationId) {
133+
ObjectNode logData = objectMapper.createObjectNode();
134+
logData.put("message", "This is a log message");
135+
ObjectNode log = buildLoggingMessage(logData);
136+
sink.send(sse.newEventBuilder()
137+
.name("message")
138+
.data(log)
139+
.build());
140+
ObjectNode result = buildToolResult(operationId, "OK");
141+
sink.send(sse.newEventBuilder()
142+
.name("message")
143+
.data(result)
144+
.build());
145+
}
146+
147+
private ObjectNode buildLoggingMessage(JsonNode message) {
148+
ObjectNode log = objectMapper.createObjectNode();
149+
log.put("jsonrpc", "2.0");
150+
log.put("method", "notifications/message");
151+
ObjectNode params = objectMapper.createObjectNode();
152+
log.set("params", params);
153+
params.put("level", "info");
154+
params.put("logger", getEndpoint());
155+
params.set("data", message);
156+
return log;
157+
}
158+
159+
private ObjectNode buildToolResult(String operationId, String result) {
160+
ObjectNode resultNode = objectMapper.createObjectNode();
161+
resultNode.put("id", operationId);
162+
resultNode.put("jsonrpc", "2.0");
163+
ObjectNode resultContent = objectMapper.createObjectNode();
164+
resultNode.set("result", resultContent);
165+
resultContent.putArray("content")
166+
.addObject()
167+
.put("type", "text")
168+
.put("text", result);
169+
return resultNode;
170+
}
171+
172+
// throw an exception if we haven't received the 'notifications/initialized' message yet
173+
private void ensureInitialized() {
174+
if (!initializationNotificationReceived) {
175+
throw new IllegalStateException("The client has not sent the 'notifications/initialized' message yet");
176+
}
177+
}
178+
179+
private void listTools(String operationId) {
180+
String response = getToolsListResponse().formatted(operationId);
181+
sink.send(sse.newEventBuilder()
182+
.name("message")
183+
.data(response)
184+
.build());
185+
}
186+
187+
protected abstract String getToolsListResponse();
188+
189+
private void initialize(String operationId) {
190+
ObjectNode initializeResponse = objectMapper.createObjectNode();
191+
initializeResponse
192+
.put("id", operationId)
193+
.put("jsonrpc", "2.0")
194+
.putObject("result")
195+
.put("protocolVersion", "2024-11-05");
196+
sink.send(sse.newEventBuilder()
197+
.name("message")
198+
.data(initializeResponse)
199+
.build());
200+
}
201+
202+
private void executeAddOperation(JsonNode message, String operationId) {
203+
int a = message.get("params").get("arguments").get("a").asInt();
204+
int b = message.get("params").get("arguments").get("b").asInt();
205+
int additionResult = a + b;
206+
ObjectNode result = buildToolResult(operationId, "The sum of " + a + " and " + b + " is " + additionResult + ".");
207+
sink.send(sse.newEventBuilder()
208+
.name("message")
209+
.data(result)
210+
.build());
211+
}
212+
213+
private void executeLongRunningOperation(JsonNode message, String operationId) {
214+
int duration = message.get("params").get("arguments").get("duration").asInt();
215+
scheduledExecutorService.schedule(() -> {
216+
ObjectNode result = buildToolResult(operationId, "Operation completed.");
217+
sink.send(sse.newEventBuilder()
218+
.name("message")
219+
.data(result)
220+
.build());
221+
}, duration, TimeUnit.SECONDS);
222+
}
223+
224+
long sendPing() {
225+
ObjectNode initializeResponse = objectMapper.createObjectNode();
226+
long id = ID_GENERATOR.incrementAndGet();
227+
initializeResponse
228+
.put("id", id)
229+
.put("jsonrpc", "2.0")
230+
.put("method", "ping");
231+
sink.send(sse.newEventBuilder()
232+
.name("message")
233+
.data(initializeResponse)
234+
.build());
235+
pendingPings.put(id, new CompletableFuture<>());
236+
return id;
237+
}
238+
239+
void stopRespondingToPings() {
240+
shouldRespondToPing = false;
241+
}
242+
}

mcp/deployment/src/test/java/io/quarkiverse/langchain4j/mcp/test/McpClaudeConfigTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class McpClaudeConfigTest {
2727
@RegisterExtension
2828
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
2929
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
30-
.addClass(MockHttpMcpServer.class)
30+
.addClasses(AbstractMockHttpMcpServer.class, MockHttpMcpServer.class)
3131
.addAsResource(new StringAsset("""
3232
{
3333
"mcpServers": {

mcp/deployment/src/test/java/io/quarkiverse/langchain4j/mcp/test/McpClientAndToolProviderCDITest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import org.junit.jupiter.api.Test;
1212
import org.junit.jupiter.api.extension.RegisterExtension;
1313

14-
import dev.langchain4j.mcp.McpToolProvider;
1514
import dev.langchain4j.mcp.client.DefaultMcpClient;
1615
import dev.langchain4j.mcp.client.McpClient;
1716
import dev.langchain4j.service.tool.ToolProvider;
1817
import io.quarkiverse.langchain4j.mcp.runtime.McpClientName;
18+
import io.quarkiverse.langchain4j.mcp.runtime.QuarkusMcpToolProvider;
1919
import io.quarkus.arc.ClientProxy;
2020
import io.quarkus.test.QuarkusUnitTest;
2121

@@ -24,7 +24,7 @@ public class McpClientAndToolProviderCDITest {
2424
@RegisterExtension
2525
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
2626
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
27-
.addClass(MockHttpMcpServer.class)
27+
.addClasses(AbstractMockHttpMcpServer.class, MockHttpMcpServer.class)
2828
.addAsResource(new StringAsset("""
2929
quarkus.langchain4j.mcp.client1.transport-type=http
3030
quarkus.langchain4j.mcp.client1.url=http://localhost:8081/mock-mcp/sse
@@ -50,7 +50,7 @@ public void test() {
5050

5151
ToolProvider provider = toolProviderCDIInstance.get();
5252
assertThat(provider).isNotNull();
53-
assertThat(ClientProxy.unwrap(provider)).isInstanceOf(McpToolProvider.class);
53+
assertThat(ClientProxy.unwrap(provider)).isInstanceOf(QuarkusMcpToolProvider.class);
5454
}
5555

5656
}

0 commit comments

Comments
 (0)