Skip to content

Add fine-grained control for MCP server(s) provided tools #1434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
toolQualifierProviders.stream().map(
ToolQualifierProvider.BuildItem::getProvider).toList());

List<String> methodMcpClientNames = gatherMethodMcpClientNames(method);

List<String> outputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, OUTPUT_GUARDRAILS);
List<String> inputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, INPUT_GUARDRAILS);

Expand All @@ -1390,8 +1392,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
overrideChatModelParamPosition, metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo,
methodToolClassInfo, switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails,
accumulatorClassName, responseAugmenterClassName);
methodToolClassInfo, methodMcpClientNames, switchToWorkerThreadForToolExecution, inputGuardrails,
outputGuardrails, accumulatorClassName, responseAugmenterClassName);
}

private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
Expand Down Expand Up @@ -1839,6 +1841,26 @@ private List<String> gatherMethodToolClassNames(MethodInfo method) {
return Arrays.stream(toolClasses).map(t -> t.name().toString()).collect(Collectors.toList());
}

private List<String> gatherMethodMcpClientNames(MethodInfo method) {
// Using the class name to keep the McpToolBox annotation in the mcp module
AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox");
if (mcpToolBoxInstance == null) {
return null;
}

AnnotationValue mcpToolBoxValue = mcpToolBoxInstance.value();
if (mcpToolBoxValue == null) {
return Collections.emptyList();
}

String[] mcpClientNames = mcpToolBoxValue.asStringArray();
if (mcpClientNames.length == 0) {
return Collections.emptyList();
}

return Arrays.asList(mcpClientNames);
}

private DotName determineChatMemorySeeder(ClassInfo iface, ClassOutput classOutput) {
List<AnnotationInstance> annotations = iface.annotations(SEED_MEMORY);
if (annotations.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public final class AiServiceMethodCreateInfo {
private final Optional<SpanInfo> spanInfo;
// support @Toolbox
private final Map<String, AnnotationLiteral<?>> toolClassInfo;
private final List<String> mcpClientNames;
private final ResponseSchemaInfo responseSchemaInfo;

// support for guardrails
Expand Down Expand Up @@ -78,6 +79,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName,
Optional<SpanInfo> spanInfo,
ResponseSchemaInfo responseSchemaInfo,
Map<String, AnnotationLiteral<?>> toolClassInfo,
List<String> mcpClientNames,
boolean switchToWorkerThreadForToolExecution,
List<String> inputGuardrailsClassNames,
List<String> outputGuardrailsClassNames,
Expand All @@ -102,6 +104,7 @@ public Type get() {
this.spanInfo = spanInfo;
this.responseSchemaInfo = responseSchemaInfo;
this.toolClassInfo = toolClassInfo;
this.mcpClientNames = mcpClientNames;
this.inputGuardrailsClassNames = inputGuardrailsClassNames;
this.outputGuardrailsClassNames = outputGuardrailsClassNames;
this.outputTokenAccumulatorClassName = outputTokenAccumulatorClassName;
Expand Down Expand Up @@ -173,6 +176,10 @@ public Map<String, AnnotationLiteral<?>> getToolClassInfo() {
return toolClassInfo;
}

public List<String> getMcpClientNames() {
return mcpClientNames;
}

public List<ToolSpecification> getToolSpecifications() {
return toolSpecifications;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
if (context.toolService.toolProvider() != null) {
toolSpecifications = toolSpecifications != null ? new ArrayList<>(toolSpecifications) : new ArrayList<>();
toolExecutors = toolExecutors != null ? new HashMap<>(toolExecutors) : new HashMap<>();
ToolProviderRequest request = new ToolProviderRequest(memoryId, userMessage);
ToolProviderRequest request = new QuarkusToolProviderRequest(memoryId, userMessage,
methodCreateInfo.getMcpClientNames());
ToolProviderResult result = context.toolService.toolProvider().provideTools(request);
for (ToolSpecification specification : result.tools().keySet()) {
toolSpecifications.add(specification);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.List;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.tool.ToolProviderRequest;

public class QuarkusToolProviderRequest extends ToolProviderRequest {

private final List<String> mcpClientNames;

public QuarkusToolProviderRequest(Object chatMemoryId, UserMessage userMessage, List<String> mcpClientNames) {
super(chatMemoryId, userMessage);
this.mcpClientNames = mcpClientNames;
}

public List<String> getMcpClientNames() {
return mcpClientNames;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package io.quarkiverse.langchain4j.mcp.test;

import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import jakarta.inject.Inject;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.sse.Sse;
import jakarta.ws.rs.sse.SseEventSink;

import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.RestStreamElementType;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

public abstract class AbstractMockHttpMcpServer {

private final AtomicLong ID_GENERATOR = new AtomicLong(new Random().nextLong(1000, 5000));

private static Logger logger = Logger.getLogger(MockHttpMcpServer.class);

private volatile boolean shouldRespondToPing = true;

// key = operation ID of the ping
// value = future that will be completed when the ping response for that ID is received
final Map<Long, CompletableFuture<Void>> pendingPings = new ConcurrentHashMap<>();

private volatile SseEventSink sink;
private volatile Sse sse;
private final ObjectMapper objectMapper = new ObjectMapper();
private volatile boolean initializationNotificationReceived = false;

@Inject
ScheduledExecutorService scheduledExecutorService;

@Path("/sse")
@Produces(MediaType.SERVER_SENT_EVENTS)
@RestStreamElementType(MediaType.TEXT_PLAIN)
public void sse(@Context SseEventSink sink, @Context Sse sse) {
this.sink = sink;
this.sse = sse;
sink.send(sse.newEventBuilder()
.id("id")
.name("endpoint")
.mediaType(MediaType.TEXT_PLAIN_TYPE)
.data("/" + getEndpoint() + "/post")
.build());
}

protected abstract String getEndpoint();

@Path("/post")
@Consumes(MediaType.APPLICATION_JSON)
@POST
public Response post(JsonNode message) {
if (message.get("method") != null) {
String method = message.get("method").asText();
if (method.equals("notifications/cancelled")) {
return Response.ok().build();
}
if (method.equals("notifications/initialized")) {
if (initializationNotificationReceived) {
return Response.serverError().entity("Duplicate 'notifications/initialized' message").build();
}
initializationNotificationReceived = true;
return Response.ok().build();
}
String operationId = message.get("id").asText();
if (method.equals("initialize")) {
initialize(operationId);
} else if (method.equals("tools/list")) {
ensureInitialized();
listTools(operationId);
} else if (method.equals("tools/call")) {
ensureInitialized();
if (message.get("params").get("name").asText().equals("add")) {
executeAddOperation(message, operationId);
} else if (message.get("params").get("name").asText().equals("logging")) {
executeLoggingOperation(message, operationId);
} else if (message.get("params").get("name").asText().equals("longRunningOperation")) {
executeLongRunningOperation(message, operationId);
} else {
return Response.serverError().entity("Unknown operation").build();
}

} else if (method.equals("ping")) {
if (shouldRespondToPing) {
ObjectNode result = buildPongMessage(operationId);
sink.send(sse.newEventBuilder()
.name("message")
.data(result)
.build());
} else {
logger.info("Ignoring ping request");
}
return Response.accepted().build();
}
} else {
// if 'method' is null, the message is probably a ping response
long id = message.get("id").asLong();
CompletableFuture<Void> future = pendingPings.remove(id);
if (future != null) {
future.complete(null);
} else {
return Response.serverError().entity("Received a ping response with unknown ID " + id).build();
}
}
return Response.accepted().build();
}

private ObjectNode buildPongMessage(String operationId) {
ObjectNode pong = objectMapper.createObjectNode();
pong.put("jsonrpc", "2.0");
pong.put("id", operationId);
pong.put("result", objectMapper.createObjectNode());
return pong;
}

private void executeLoggingOperation(JsonNode message, String operationId) {
ObjectNode logData = objectMapper.createObjectNode();
logData.put("message", "This is a log message");
ObjectNode log = buildLoggingMessage(logData);
sink.send(sse.newEventBuilder()
.name("message")
.data(log)
.build());
ObjectNode result = buildToolResult(operationId, "OK");
sink.send(sse.newEventBuilder()
.name("message")
.data(result)
.build());
}

private ObjectNode buildLoggingMessage(JsonNode message) {
ObjectNode log = objectMapper.createObjectNode();
log.put("jsonrpc", "2.0");
log.put("method", "notifications/message");
ObjectNode params = objectMapper.createObjectNode();
log.set("params", params);
params.put("level", "info");
params.put("logger", getEndpoint());
params.set("data", message);
return log;
}

private ObjectNode buildToolResult(String operationId, String result) {
ObjectNode resultNode = objectMapper.createObjectNode();
resultNode.put("id", operationId);
resultNode.put("jsonrpc", "2.0");
ObjectNode resultContent = objectMapper.createObjectNode();
resultNode.set("result", resultContent);
resultContent.putArray("content")
.addObject()
.put("type", "text")
.put("text", result);
return resultNode;
}

// throw an exception if we haven't received the 'notifications/initialized' message yet
private void ensureInitialized() {
if (!initializationNotificationReceived) {
throw new IllegalStateException("The client has not sent the 'notifications/initialized' message yet");
}
}

private void listTools(String operationId) {
String response = getToolsListResponse().formatted(operationId);
sink.send(sse.newEventBuilder()
.name("message")
.data(response)
.build());
}

protected abstract String getToolsListResponse();

private void initialize(String operationId) {
ObjectNode initializeResponse = objectMapper.createObjectNode();
initializeResponse
.put("id", operationId)
.put("jsonrpc", "2.0")
.putObject("result")
.put("protocolVersion", "2024-11-05");
sink.send(sse.newEventBuilder()
.name("message")
.data(initializeResponse)
.build());
}

private void executeAddOperation(JsonNode message, String operationId) {
int a = message.get("params").get("arguments").get("a").asInt();
int b = message.get("params").get("arguments").get("b").asInt();
int additionResult = a + b;
ObjectNode result = buildToolResult(operationId, "The sum of " + a + " and " + b + " is " + additionResult + ".");
sink.send(sse.newEventBuilder()
.name("message")
.data(result)
.build());
}

private void executeLongRunningOperation(JsonNode message, String operationId) {
int duration = message.get("params").get("arguments").get("duration").asInt();
scheduledExecutorService.schedule(() -> {
ObjectNode result = buildToolResult(operationId, "Operation completed.");
sink.send(sse.newEventBuilder()
.name("message")
.data(result)
.build());
}, duration, TimeUnit.SECONDS);
}

long sendPing() {
ObjectNode initializeResponse = objectMapper.createObjectNode();
long id = ID_GENERATOR.incrementAndGet();
initializeResponse
.put("id", id)
.put("jsonrpc", "2.0")
.put("method", "ping");
sink.send(sse.newEventBuilder()
.name("message")
.data(initializeResponse)
.build());
pendingPings.put(id, new CompletableFuture<>());
return id;
}

void stopRespondingToPings() {
shouldRespondToPing = false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class McpClaudeConfigTest {
@RegisterExtension
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClass(MockHttpMcpServer.class)
.addClasses(AbstractMockHttpMcpServer.class, MockHttpMcpServer.class)
.addAsResource(new StringAsset("""
{
"mcpServers": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.mcp.McpToolProvider;
import dev.langchain4j.mcp.client.DefaultMcpClient;
import dev.langchain4j.mcp.client.McpClient;
import dev.langchain4j.service.tool.ToolProvider;
import io.quarkiverse.langchain4j.mcp.runtime.McpClientName;
import io.quarkiverse.langchain4j.mcp.runtime.QuarkusMcpToolProvider;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

Expand All @@ -24,7 +24,7 @@ public class McpClientAndToolProviderCDITest {
@RegisterExtension
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClass(MockHttpMcpServer.class)
.addClasses(AbstractMockHttpMcpServer.class, MockHttpMcpServer.class)
.addAsResource(new StringAsset("""
quarkus.langchain4j.mcp.client1.transport-type=http
quarkus.langchain4j.mcp.client1.url=http://localhost:8081/mock-mcp/sse
Expand All @@ -50,7 +50,7 @@ public void test() {

ToolProvider provider = toolProviderCDIInstance.get();
assertThat(provider).isNotNull();
assertThat(ClientProxy.unwrap(provider)).isInstanceOf(McpToolProvider.class);
assertThat(ClientProxy.unwrap(provider)).isInstanceOf(QuarkusMcpToolProvider.class);
}

}
Loading