feat: add ToolReturn for customizable tool return #2060

Open · wants to merge 17 commits into `main`
60 changes: 60 additions & 0 deletions docs/tools.md
@@ -293,6 +293,66 @@ _(This example is complete, it can be run "as is")_

Some models (e.g. Gemini) natively support semi-structured return values, while others (e.g. OpenAI) expect text but are generally just as good at extracting meaning from it. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
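
For example, a tool that returns a `dict` will have its value serialized to JSON before being sent to a model that expects text. A minimal sketch (the model name is just an example):

```python {title="tool_return_json.py" test="skip" lint="skip"}
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

@agent.tool_plain
def get_user() -> dict[str, str]:
    """Return a Python object; it is serialized to JSON for models that expect text."""
    return {'name': 'Ada', 'role': 'admin'}
```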

### Advanced Tool Returns

For scenarios where you need more control over both the tool's return value and the content sent to the model, you can use [`ToolReturn`][pydantic_ai.messages.ToolReturn]. This is particularly useful when you want to:

- Provide rich multi-modal content (images, documents, etc.) to the model as context
- Separate the programmatic return value from the model's context
- Include additional metadata that shouldn't be sent to the LLM

Here's an example of a computer automation tool that captures screenshots and provides visual feedback:

```python {title="advanced_tool_return.py" test="skip" lint="skip"}
import time

from pydantic_ai import Agent
from pydantic_ai.messages import ToolReturn, BinaryContent

agent = Agent('openai:gpt-4o')

@agent.tool_plain
def click_and_capture(x: int, y: int) -> ToolReturn:
    """Click at coordinates and show before/after screenshots."""
    # Take screenshot before action.
    # (`capture_screen` and `perform_click` are placeholders for your own
    # screen-automation helpers; they are not part of Pydantic AI.)
    before_screenshot = capture_screen()

    # Perform click operation
    perform_click(x, y)
    time.sleep(0.5)  # Wait for UI to update

    # Take screenshot after action
    after_screenshot = capture_screen()

    return ToolReturn(
        return_value=f"Successfully clicked at ({x}, {y})",
        content=[
            f"Clicked at coordinates ({x}, {y}). Here's the comparison:",
            "Before:",
            BinaryContent(data=before_screenshot, media_type="image/png"),
            "After:",
            BinaryContent(data=after_screenshot, media_type="image/png"),
            "Please analyze the changes and suggest next steps.",
        ],
        metadata={
            "coordinates": {"x": x, "y": y},
            "action_type": "click_and_capture",
            "timestamp": time.time(),
        },
    )

# The model receives the rich visual content for analysis,
# while your application can access the structured return_value and metadata.
result = agent.run_sync("Click on the submit button and tell me what happened")
print(result.output)
# The model can analyze the screenshots and provide detailed feedback
```

- **`return_value`**: The actual return value used in the tool response. This is what gets serialized and sent back to the model as the tool's result.
- **`content`**: A sequence of content (text, images, documents, etc.) that provides additional context to the model. This appears as a separate user message.
- **`metadata`**: Optional metadata that your application can access but is not sent to the LLM. Useful for logging, debugging, or additional processing.
Review comment (Contributor): Let's add "Some other AI frameworks call this feature 'artifacts'", in case someone searches for that.


This separation allows you to provide rich context to the model while maintaining clean, structured return values for your application logic.
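
Because `metadata` is stored on the [`ToolReturnPart`][pydantic_ai.messages.ToolReturnPart] but never sent to the model, your application can read it back from the message history. A minimal sketch, reusing `result` from the example above:

```python {title="tool_return_metadata.py" test="skip" lint="skip"}
from pydantic_ai.messages import ModelRequest, ToolReturnPart

for message in result.all_messages():
    if isinstance(message, ModelRequest):
        for part in message.parts:
            if isinstance(part, ToolReturnPart):
                print(part.metadata)  # e.g. {'coordinates': {'x': 100, 'y': 200}, ...}
```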

## Function Tools vs. Structured Outputs

As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output.
34 changes: 32 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -743,9 +743,32 @@ async def process_function_tools(  # noqa C901
if isinstance(result, _messages.RetryPromptPart):
    results_by_index[index] = result
elif isinstance(result, _messages.ToolReturnPart):
    if isinstance(result.content, _messages.ToolReturn):
        tool_return = result.content
        result.content = tool_return.return_value
        if (
            isinstance(result.content, _messages.MultiModalContentTypes)
            or isinstance(result.content, list)
            and any(
                isinstance(content, _messages.MultiModalContentTypes)
                for content in result.content  # type: ignore
            )
        ):
            raise exceptions.UserError(
                f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. "
                f'Please use `content` instead.'
            )
        result.metadata = tool_return.metadata
        user_parts.append(
            _messages.UserPromptPart(
                content=list(tool_return.content),
                timestamp=result.timestamp,
                part_kind='user-prompt',
            )
        )
    contents: list[Any]
    single_content: bool
    if isinstance(result.content, list):  # type: ignore
        contents = result.content  # type: ignore
        single_content = False
    else:
@@ -754,7 +777,13 @@

    processed_contents: list[Any] = []
    for content in contents:
        if isinstance(content, _messages.ToolReturn):
            raise exceptions.UserError(
                f"{result.tool_name}'s return contains invalid nested ToolReturn objects. "
                f'ToolReturn should be used directly.'
            )
        elif isinstance(content, _messages.MultiModalContentTypes):
            # Handle direct multimodal content
            if isinstance(content, _messages.BinaryContent):
                identifier = multi_modal_content_identifier(content.data)
            else:
@@ -769,6 +798,7 @@
            )
            processed_contents.append(f'See file {identifier}')
        else:
            # Handle regular content
            processed_contents.append(content)

    if single_content:
26 changes: 26 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -306,6 +306,29 @@ def format(self) -> str:

UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'


@dataclass(repr=False)
class ToolReturn:
    """A structured return value for tools that need to provide both a return value and custom content to the model.

    This class allows tools to return complex responses that include:

    - A return value used as the actual tool result
    - Custom content (including multi-modal content) to be sent to the model as a UserPromptPart
    - Optional metadata for application use
    """

    return_value: Any
    """The return value to be used in the tool response."""

    content: Sequence[UserContent]
Review comment (Contributor): Let's make this one optional, in case the user wants to use only a `return_value` and `metadata`.
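
A possible sketch of that change (illustration only, not part of this diff; assumes `from __future__ import annotations` or Python 3.10+ union syntax):

```python
content: Sequence[UserContent] | None = None
```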

"""The content sequence to be sent to the model as a UserPromptPart."""

    metadata: Any = None
    """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

    __repr__ = _utils.dataclasses_no_defaults_repr


# Ideally this would be a Union of types, but Python 3.9 requires it to be a string, and strings don't work with `isinstance`.
MultiModalContentTypes = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)
_document_format_lookup: dict[str, DocumentFormat] = {
@@ -396,6 +419,9 @@ class ToolReturnPart:
    tool_call_id: str
    """The tool call identifier, this is used by some models including OpenAI."""

    metadata: Any = None
    """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

    timestamp: datetime = field(default_factory=_now_utc)
    """The timestamp, when the tool returned."""

1 change: 1 addition & 0 deletions tests/models/test_model_function.py
@@ -231,6 +231,7 @@ def test_var_args():
        'tool_name': 'get_var_args',
        'content': '{"args": [1, 2, 3]}',
        'tool_call_id': IsStr(),
        'metadata': None,
        'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc),  # type: ignore[reportUnknownMemberType]
        'part_kind': 'tool-return',
    }
153 changes: 153 additions & 0 deletions tests/test_agent.py
@@ -24,6 +24,7 @@
from pydantic_ai.agent import AgentRunResult
from pydantic_ai.messages import (
    BinaryContent,
    ImageUrl,
    ModelMessage,
    ModelMessagesTypeAdapter,
    ModelRequest,
@@ -33,6 +34,7 @@
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturn,
    ToolReturnPart,
    UserPromptPart,
)
@@ -3127,3 +3129,154 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:

    with pytest.raises(UserError, match='Output tools are not supported by the model.'):
        agent.run_sync('Hello')


def test_multimodal_tool_response():
    """Test ToolReturn with custom content and tool return."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        if len(messages) == 1:
            return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})])
        else:
            return ModelResponse(
                parts=[
                    TextPart('Analysis completed'),
                ]
            )

    agent = Agent(FunctionModel(llm))

    @agent.tool_plain
    def analyze_data() -> ToolReturn:
        return ToolReturn(
            return_value='Data analysis completed successfully',
            content=[
                'Here are the analysis results:',
                ImageUrl('https://example.com/chart.jpg'),
                'The chart shows positive trends.',
            ],
            metadata={'foo': 'bar'},
        )

    result = agent.run_sync('Please analyze the data')

    # Verify final output
    assert result.output == 'Analysis completed'

    # Verify the complete message structure using snapshot
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Please analyze the data', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    TextPart(content='Starting analysis'),
                    ToolCallPart(
                        tool_name='analyze_data',
                        args={},
                        tool_call_id=IsStr(),
                    ),
                ],
                usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58),
                model_name='function:llm:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='analyze_data',
                        content='Data analysis completed successfully',
                        tool_call_id=IsStr(),
                        metadata={'foo': 'bar'},
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                    UserPromptPart(
                        content=[
                            'Here are the analysis results:',
                            ImageUrl(url='https://example.com/chart.jpg'),
                            'The chart shows positive trends.',
                        ],
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='Analysis completed')],
                usage=Usage(requests=1, request_tokens=70, response_tokens=6, total_tokens=76),
                model_name='function:llm:',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )


def test_many_multimodal_tool_response():
    """Test that a ToolReturn nested inside a list return raises a UserError."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        if len(messages) == 1:
            return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})])
        else:
            return ModelResponse(  # pragma: no cover
                parts=[
                    TextPart('Analysis completed'),
                ]
            )

    agent = Agent(FunctionModel(llm))

    @agent.tool_plain
    def analyze_data() -> list[Any]:
        return [
            ToolReturn(
                return_value='Data analysis completed successfully',
                content=[
                    'Here are the analysis results:',
                    ImageUrl('https://example.com/chart.jpg'),
                    'The chart shows positive trends.',
                ],
                metadata={'foo': 'bar'},
            ),
            'Something else',
        ]

    with pytest.raises(
        UserError,
        match="analyze_data's return contains invalid nested ToolReturn objects. ToolReturn should be used directly.",
    ):
        agent.run_sync('Please analyze the data')


def test_multimodal_tool_response_nested():
    """Test that multi-modal content in `return_value` raises a UserError."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        if len(messages) == 1:
            return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})])
        else:
            return ModelResponse(  # pragma: no cover
                parts=[
                    TextPart('Analysis completed'),
                ]
            )

    agent = Agent(FunctionModel(llm))

    @agent.tool_plain
    def analyze_data() -> ToolReturn:
        return ToolReturn(
            return_value=ImageUrl('https://example.com/chart.jpg'),
            content=[
                'Here are the analysis results:',
                ImageUrl('https://example.com/chart.jpg'),
                'The chart shows positive trends.',
            ],
            metadata={'foo': 'bar'},
        )

    with pytest.raises(
        UserError,
        match="analyze_data's `return_value` contains invalid nested MultiModalContentTypes objects. Please use `content` instead.",
    ):
        agent.run_sync('Please analyze the data')