Skip to content

Commit 882fc50

Browse files
committed
Add a way to validate that MCP tool descriptions
This is done by utilizing an LLM to detect whether the tool description is malicious and could lead to a Tool Poisoning Attack (TPA)
1 parent 4f6440f commit 882fc50

File tree

3 files changed

+35
-26
lines changed

3 files changed

+35
-26
lines changed

mcp/deployment/src/main/java/io/quarkiverse/langchain4j/mcp/deployment/McpProcessor.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
import org.jboss.jandex.AnnotationInstance;
1111
import org.jboss.jandex.ClassType;
1212
import org.jboss.jandex.DotName;
13+
import org.jboss.jandex.ParameterizedType;
14+
import org.jboss.jandex.Type;
1315

1416
import dev.langchain4j.mcp.client.McpClient;
17+
import dev.langchain4j.model.chat.ChatLanguageModel;
1518
import dev.langchain4j.service.tool.ToolProvider;
19+
import io.quarkiverse.langchain4j.deployment.DotNames;
1620
import io.quarkiverse.langchain4j.mcp.runtime.McpClientName;
1721
import io.quarkiverse.langchain4j.mcp.runtime.McpRecorder;
1822
import io.quarkiverse.langchain4j.mcp.runtime.config.McpBuildTimeConfiguration;
@@ -48,12 +52,14 @@ public void registerMcpClients(McpBuildTimeConfiguration mcpBuildTimeConfigurati
4852
beanProducer.produce(SyntheticBeanBuildItem
4953
.configure(MCP_CLIENT)
5054
.addQualifier(qualifier)
55+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
56+
new Type[] { ClassType.create(ChatLanguageModel.class) }, null))
5157
.setRuntimeInit()
5258
.defaultBean()
5359
.unremovable()
5460
// TODO: should we allow other scopes?
5561
.scope(ApplicationScoped.class)
56-
.supplier(
62+
.createWith(
5763
recorder.mcpClientSupplier(client.getKey(), mcpBuildTimeConfiguration, mcpRuntimeConfiguration))
5864
.done());
5965
}

mcp/runtime/src/main/java/io/quarkiverse/langchain4j/mcp/runtime/McpRecorder.java

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import java.util.List;
55
import java.util.Set;
66
import java.util.function.Function;
7-
import java.util.function.Supplier;
87

98
import dev.langchain4j.mcp.McpToolProvider;
109
import dev.langchain4j.mcp.client.DefaultMcpClient;
@@ -24,42 +23,40 @@
2423
@Recorder
2524
public class McpRecorder {
2625

27-
public Supplier<McpClient> mcpClientSupplier(String key, McpBuildTimeConfiguration buildTimeConfiguration,
26+
public Function<SyntheticCreationalContext<McpClient>, McpClient> mcpClientSupplier(String clientName,
27+
McpBuildTimeConfiguration buildTimeConfiguration,
2828
McpRuntimeConfiguration mcpRuntimeConfiguration) {
29-
return new Supplier<McpClient>() {
29+
return new Function<>() {
3030
@Override
31-
public McpClient get() {
32-
McpTransport transport = null;
33-
McpClientBuildTimeConfig buildTimeConfig = buildTimeConfiguration.clients().get(key);
34-
McpClientRuntimeConfig runtimeConfig = mcpRuntimeConfiguration.clients().get(key);
35-
switch (buildTimeConfig.transportType()) {
36-
case STDIO:
31+
public McpClient apply(SyntheticCreationalContext<McpClient> context) {
32+
McpTransport transport;
33+
McpClientBuildTimeConfig buildTimeConfig = buildTimeConfiguration.clients().get(clientName);
34+
McpClientRuntimeConfig runtimeConfig = mcpRuntimeConfiguration.clients().get(clientName);
35+
transport = switch (buildTimeConfig.transportType()) {
36+
case STDIO -> {
3737
List<String> command = runtimeConfig.command().orElseThrow(() -> new ConfigurationException(
38-
"MCP client configuration named " + key + " is missing the 'command' property"));
39-
transport = new StdioMcpTransport.Builder()
38+
"MCP client configuration named " + clientName + " is missing the 'command' property"));
39+
yield new StdioMcpTransport.Builder()
4040
.command(command)
4141
.logEvents(runtimeConfig.logResponses().orElse(false))
4242
.environment(runtimeConfig.environment())
4343
.build();
44-
break;
45-
case HTTP:
46-
transport = new QuarkusHttpMcpTransport.Builder()
47-
.sseUrl(runtimeConfig.url().orElseThrow(() -> new ConfigurationException(
48-
"MCP client configuration named " + key + " is missing the 'url' property")))
49-
.logRequests(runtimeConfig.logRequests().orElse(false))
50-
.logResponses(runtimeConfig.logResponses().orElse(false))
51-
.build();
52-
break;
53-
default:
54-
throw new IllegalArgumentException("Unknown transport type: " + buildTimeConfig.transportType());
55-
}
56-
return new DefaultMcpClient.Builder()
44+
}
45+
case HTTP -> new QuarkusHttpMcpTransport.Builder()
46+
.sseUrl(runtimeConfig.url().orElseThrow(() -> new ConfigurationException(
47+
"MCP client configuration named " + clientName + " is missing the 'url' property")))
48+
.logRequests(runtimeConfig.logRequests().orElse(false))
49+
.logResponses(runtimeConfig.logResponses().orElse(false))
50+
.build();
51+
};
52+
var result = new DefaultMcpClient.Builder()
5753
.transport(transport)
5854
.toolExecutionTimeout(runtimeConfig.toolExecutionTimeout())
5955
.resourcesTimeout(runtimeConfig.resourcesTimeout())
6056
// TODO: it should be possible to choose a log handler class via configuration
61-
.logHandler(new QuarkusDefaultMcpLogHandler(key))
57+
.logHandler(new QuarkusDefaultMcpLogHandler(clientName))
6258
.build();
59+
return result;
6360
}
6461
};
6562
}

mcp/runtime/src/main/java/io/quarkiverse/langchain4j/mcp/runtime/config/McpClientRuntimeConfig.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,10 @@ public interface McpClientRuntimeConfig {
5858
@WithDefault("60s")
5959
Duration resourcesTimeout();
6060

61+
/**
62+
* The named model to use in order to judge whether the descriptions of the tools provided by the MCP server
63+
* are malicious. If they are, a warning will be printed and the tool will never be used.
64+
*/
65+
Optional<String> guardrailModelName();
66+
6167
}

0 commit comments

Comments
 (0)