Skip to content

feat(llmobs): distributed tracing for mcp #14045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ ddtrace/contrib/internal/crewai @DataDog/ml-observ
ddtrace/contrib/internal/openai_agents @DataDog/ml-observability
ddtrace/contrib/internal/litellm @DataDog/ml-observability
ddtrace/contrib/internal/pydantic_ai @DataDog/ml-observability
ddtrace/contrib/internal/mcp @DataDog/ml-observability
tests/llmobs @DataDog/ml-observability
tests/contrib/openai @DataDog/ml-observability
tests/contrib/langchain @DataDog/ml-observability
Expand All @@ -170,6 +171,7 @@ tests/contrib/crewai @DataDog/ml-observ
tests/contrib/openai_agents @DataDog/ml-observability
tests/contrib/litellm @DataDog/ml-observability
tests/contrib/pydantic_ai @DataDog/ml-observability
tests/contrib/mcp @DataDog/ml-observability
.gitlab/tests/llmobs.yml @DataDog/ml-observability
# MLObs snapshot tests
tests/snapshots/tests.contrib.anthropic.* @DataDog/ml-observability
Expand Down
6 changes: 6 additions & 0 deletions ddtrace/contrib/internal/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
variables.
Default: ``DD_SERVICE``

.. py:data:: ddtrace.config.mcp["distributed_tracing"]
Whether or not to enable distributed tracing for MCP requests.
Alternatively, you can set this option with the ``DD_MCP_DISTRIBUTED_TRACING`` environment
variable.
Default: ``True``

Instance Configuration
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
88 changes: 87 additions & 1 deletion ddtrace/contrib/internal/mcp/patch.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
import os
import sys
from typing import Any
from typing import Dict
from typing import Optional

import mcp

from ddtrace import config
from ddtrace.contrib.internal.trace_utils import activate_distributed_headers
from ddtrace.contrib.trace_utils import unwrap
from ddtrace.contrib.trace_utils import with_traced_module
from ddtrace.contrib.trace_utils import wrap
from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils.formats import asbool
from ddtrace.llmobs._integrations.mcp import CLIENT_TOOL_CALL_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import SERVER_TOOL_CALL_OPERATION_NAME
from ddtrace.llmobs._integrations.mcp import MCPIntegration
from ddtrace.llmobs._utils import _get_attr
from ddtrace.propagation.http import HTTPPropagator
from ddtrace.trace import Pin


config._add("mcp", {})
log = get_logger(__name__)

config._add(
"mcp",
{
"distributed_tracing": asbool(os.getenv("DD_MCP_DISTRIBUTED_TRACING", default=True)),
},
)


def get_version() -> str:
Expand All @@ -26,6 +41,71 @@ def _supported_versions() -> Dict[str, str]:
return {"mcp": ">=1.10.0"}


def _set_distributed_headers_into_mcp_request(pin, request):
"""Inject distributed tracing headers into MCP request metadata."""
span = pin.tracer.current_span()
if span is None:
return request

headers = {}
HTTPPropagator.inject(span.context, headers)
if not headers:
return request
if _get_attr(request, "root", None) is None:
return request

try:
request_params = _get_attr(request.root, "params", None)
if not request_params:
return request

# Use the `_meta` field to store tracing headers. It is accessed via a public
# `meta` attribute on the request params. This field is reserved for server/clients
# to attach additional metadata to a request. For more information, see:
# https://modelcontextprotocol.io/specification/2025-06-18/basic#meta
existing_meta = _get_attr(request_params, "meta", None)
meta_dict = existing_meta.model_dump() if existing_meta else {}

meta_dict["_dd_trace_context"] = headers
params_dict = request_params.model_dump(by_alias=True)
params_dict["_meta"] = meta_dict

new_params = type(request_params)(**params_dict)
request_dict = request.root.model_dump()
request_dict["params"] = new_params

new_request_root = type(request.root)(**request_dict)
return type(request)(new_request_root)
except Exception:
log.error("Error injecting distributed tracing headers into MCP request metadata", exc_info=True)
return request


def _extract_distributed_headers_from_mcp_request(kwargs: Dict[str, Any]) -> Optional[Dict[str, str]]:
if "context" not in kwargs:
return
context = kwargs.get("context")
if not context or not _get_attr(context, "request_context", None):
return
request_context = _get_attr(context, "request_context", None)
meta = _get_attr(request_context, "meta", None)
if not meta:
return
headers = _get_attr(meta, "_dd_trace_context", None)
if headers:
return headers


@with_traced_module
def traced_send_request(mcp, pin, func, instance, args, kwargs):
    """Injects distributed tracing headers into MCP request metadata"""
    if args and config.mcp.distributed_tracing:
        # Rewrite only the first positional argument (the outgoing request).
        request = _set_distributed_headers_into_mcp_request(pin, args[0])
        return func(request, *args[1:], **kwargs)
    return func(*args, **kwargs)


@with_traced_module
async def traced_call_tool(mcp, pin, func, instance, args, kwargs):
integration = mcp._datadog_integration
Expand All @@ -51,6 +131,8 @@ async def traced_call_tool(mcp, pin, func, instance, args, kwargs):
@with_traced_module
async def traced_tool_manager_call_tool(mcp, pin, func, instance, args, kwargs):
integration = mcp._datadog_integration
if config.mcp.distributed_tracing:
activate_distributed_headers(pin.tracer, config.mcp, _extract_distributed_headers_from_mcp_request(kwargs))

span = integration.trace(pin, SERVER_TOOL_CALL_OPERATION_NAME, submit_to_llmobs=True)

Expand Down Expand Up @@ -80,7 +162,9 @@ def patch():

from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession

wrap(BaseSession, "send_request", traced_send_request(mcp))
wrap(ClientSession, "call_tool", traced_call_tool(mcp))
wrap(ToolManager, "call_tool", traced_tool_manager_call_tool(mcp))

Expand All @@ -93,7 +177,9 @@ def unpatch():

from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession

unwrap(BaseSession, "send_request")
unwrap(ClientSession, "call_tool")
unwrap(ToolManager, "call_tool")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
LLM Observability, mcp: Adds distributed tracing support for MCP tool calls across client-server boundaries by default.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LLM Observability, mcp: Adds distributed tracing support for MCP tool calls across client-server boundaries by default.
mcp: Adds distributed tracing support for MCP tool calls across client-server boundaries by default, for both APM and LLMObs traces.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the past for LLM Obs bug fixes/features that were integration specific i used LLM Observability: -- should that still apply here? i always viewed <integration-name>: to mean apm specific changes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make it two release notes then to be clear, but imo this is a clear enough release note

To disable distributed tracing for MCP, set the environment variable `DD_MCP_DISTRIBUTED_TRACING=False` on both the client and the server.
68 changes: 68 additions & 0 deletions tests/contrib/mcp/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from http.server import BaseHTTPRequestHandler
from http.server import HTTPServer
import json
import threading
import time

from mcp.server.fastmcp import FastMCP
from mcp.shared.memory import create_connected_server_and_client_session
import pytest
Expand All @@ -12,6 +18,27 @@
from tests.utils import override_global_config


class LLMObsServer(BaseHTTPRequestHandler):
    """A mock server for the LLMObs backend used to capture the requests made by the client.

    Python's HTTPRequestHandler is a bit weird and uses a class rather than an instance
    for running an HTTP server so the requests are stored in a class variable and reset
    in the pytest fixture.
    """

    # Captured POST requests, shared across all handler instances; reset per
    # test by the `_llmobs_backend` fixture. (A redundant no-op __init__
    # override was removed — the base class constructor runs unchanged.)
    requests = []

    def do_POST(self) -> None:
        """Record the request (path, headers, body) and reply 200 OK."""
        content_length = int(self.headers["Content-Length"])
        body = self.rfile.read(content_length).decode("utf-8")
        self.requests.append({"path": self.path, "headers": dict(self.headers), "body": body})
        self.send_response(200)
        self.end_headers()
        self.wfile.write(b"OK")


@pytest.fixture(autouse=True)
def mcp_setup():
patch()
Expand Down Expand Up @@ -114,3 +141,44 @@ async def mcp_client(mcp_server):
async with create_connected_server_and_client_session(mcp_server._mcp_server) as client:
await client.initialize()
yield client


@pytest.fixture
def _llmobs_backend():
    """Run a throwaway HTTP server recording LLMObs payloads for one test."""
    LLMObsServer.requests = []
    backend = HTTPServer(("localhost", 0), LLMObsServer)
    worker = threading.Thread(target=backend.serve_forever, daemon=True)
    worker.start()

    # Hand the test the base URL plus the shared (class-level) request log.
    host, port = backend.server_address[0], backend.server_address[1]
    yield f"http://{host}:{port}", LLMObsServer.requests

    # Tear the server down once the test finishes.
    backend.shutdown()
    backend.server_close()


@pytest.fixture
def llmobs_backend(_llmobs_backend):
    """Wrap the raw backend fixture in a small helper object for tests."""
    import pprint

    _url, reqs = _llmobs_backend

    class _LLMObsBackend:
        def url(self):
            return _url

        def wait_for_num_events(self, num, attempts=1000):
            """Poll until exactly `num` requests arrived, or time out (~1s)."""
            for _ in range(attempts):
                if len(reqs) == num:
                    return [json.loads(r["body"]) for r in reqs]
                # time.sleep will yield the GIL so the server can process the request
                time.sleep(0.001)
            # Bug fix: pprint.pprint() prints to stdout and returns None, so the
            # original message always ended in "None"; pformat() returns the string.
            raise TimeoutError(f"Expected {num} events, got {len(reqs)}: {pprint.pformat(reqs)}")

    return _LLMObsBackend()
73 changes: 67 additions & 6 deletions tests/contrib/mcp/test_mcp_llmobs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
import json
import os
from textwrap import dedent

import mock

from tests.llmobs._utils import _expected_llmobs_non_llm_span_event


def _get_client_and_server_spans_and_events(mock_tracer, llmobs_events):
"""Get client and server spans and events for testing."""
def _assert_distributed_trace(mock_tracer, llmobs_events, expected_tool_name):
"""Assert that client and server spans have the same trace ID and return client/server spans and LLM Obs events."""
traces = mock_tracer.pop_traces()
assert len(traces) >= 1

Expand All @@ -19,6 +21,9 @@ def _get_client_and_server_spans_and_events(mock_tracer, llmobs_events):

assert len(client_spans) >= 1 and len(server_spans) >= 1
assert len(client_events) >= 1 and len(server_events) >= 1
assert client_spans[0].trace_id == server_spans[0].trace_id
assert client_events[0]["trace_id"] == server_events[0]["trace_id"]
assert client_events[0]["_dd"]["apm_trace_id"] == server_events[0]["_dd"]["apm_trace_id"]

return all_spans, client_events, server_events, client_spans, server_spans

Expand All @@ -27,8 +32,8 @@ def test_llmobs_mcp_client_calls_server(mcp_setup, mock_tracer, llmobs_events, m
"""Test that LLMObs records are emitted for both client and server MCP operations."""
asyncio.run(mcp_call_tool("calculator", {"operation": "add", "a": 20, "b": 22}))

all_spans, client_events, server_events, client_spans, server_spans = _get_client_and_server_spans_and_events(
mock_tracer, llmobs_events
all_spans, client_events, server_events, client_spans, server_spans = _assert_distributed_trace(
mock_tracer, llmobs_events, "calculator"
)

assert len(all_spans) == 2
Expand Down Expand Up @@ -63,8 +68,8 @@ def test_llmobs_client_server_tool_error(mcp_setup, mock_tracer, llmobs_events,
"""Test error handling in both client and server MCP operations."""
asyncio.run(mcp_call_tool("failing_tool", {"param": "value"}))

all_spans, client_events, server_events, client_spans, server_spans = _get_client_and_server_spans_and_events(
mock_tracer, llmobs_events
all_spans, client_events, server_events, client_spans, server_spans = _assert_distributed_trace(
mock_tracer, llmobs_events, "failing_tool"
)

assert len(all_spans) == 2
Expand Down Expand Up @@ -105,3 +110,59 @@ def test_llmobs_client_server_tool_error(mcp_setup, mock_tracer, llmobs_events,
error_message="Error executing tool failing_tool: Tool execution failed",
error_stack=mock.ANY,
)


def test_mcp_distributed_tracing_disabled_env(ddtrace_run_python_code_in_subprocess, llmobs_backend):
    """Test that distributed tracing is disabled when DD_MCP_DISTRIBUTED_TRACING=false."""
    env = os.environ.copy()
    env.update(
        {
            "DD_LLMOBS_ML_APP": "test-ml-app",
            "DD_API_KEY": "test-api-key",
            "DD_LLMOBS_ENABLED": "1",
            "DD_LLMOBS_AGENTLESS_ENABLED": "0",
            "DD_TRACE_AGENT_URL": llmobs_backend.url(),
            "DD_MCP_DISTRIBUTED_TRACING": "false",
        }
    )
    script = dedent(
        """
        import asyncio
        import logging
        import warnings

        logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING)
        warnings.filterwarnings("ignore", message="OpenTelemetry configuration.*not supported by Datadog")

        from ddtrace.llmobs import LLMObs
        LLMObs.enable()

        from mcp.server.fastmcp import FastMCP
        from mcp.shared.memory import create_connected_server_and_client_session

        mcp = FastMCP(name="TestServer")

        @mcp.tool(description="Get weather for a location")
        def get_weather(location: str) -> str:
            return f"Weather in {location} is 72°F"

        async def test():
            async with create_connected_server_and_client_session(mcp._mcp_server) as client:
                await client.initialize()
                await client.call_tool("get_weather", {"location": "San Francisco"})

        asyncio.run(test())
        """
    )
    out, err, status, _ = ddtrace_run_python_code_in_subprocess(script, env=env)
    assert out == b""
    assert status == 0, err

    traces = llmobs_backend.wait_for_num_events(num=1)[0]
    assert len(traces) == 2

    # With propagation disabled, client and server land in two unrelated traces.
    found = {}
    for trace in traces:
        root_name = trace["spans"][0]["name"]
        if "Client Tool Call" in root_name:
            found["client"] = trace
        elif "Server Tool Execute" in root_name:
            found["server"] = trace
    client_trace = found.get("client")
    server_trace = found.get("server")

    assert client_trace is not None
    assert server_trace is not None
    client_span, server_span = client_trace["spans"][0], server_trace["spans"][0]
    assert client_span["trace_id"] != server_span["trace_id"]
    assert client_span["_dd"]["apm_trace_id"] != server_span["_dd"]["apm_trace_id"]
6 changes: 6 additions & 0 deletions tests/contrib/mcp/test_mcp_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,26 @@ class TestMCPPatch(PatchTestCase.Base):
def assert_module_patched(self, mcp):
from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession

self.assert_wrapped(BaseSession.send_request)
self.assert_wrapped(ClientSession.call_tool)
self.assert_wrapped(ToolManager.call_tool)

def assert_not_module_patched(self, mcp):
from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession

self.assert_not_wrapped(BaseSession.send_request)
self.assert_not_wrapped(ClientSession.call_tool)
self.assert_not_wrapped(ToolManager.call_tool)

def assert_not_module_double_patched(self, mcp):
from mcp.client.session import ClientSession
from mcp.server.fastmcp.tools.tool_manager import ToolManager
from mcp.shared.session import BaseSession

self.assert_not_double_wrapped(BaseSession.send_request)
self.assert_not_double_wrapped(ClientSession.call_tool)
self.assert_not_double_wrapped(ToolManager.call_tool)
Loading
Loading