Retry prompt model response template #2008

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion docs/agents.md
@@ -852,7 +852,7 @@ with capture_run_messages() as messages: # (2)!
)
],
usage=Usage(
- requests=1, request_tokens=72, response_tokens=8, total_tokens=80
+ requests=1, request_tokens=74, response_tokens=8, total_tokens=82
),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
@@ -423,6 +423,9 @@ def otel_event(self, _settings: InstrumentationSettings) -> Event:
error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))


+ DEFAULT_MODEL_RESPONSE_TEMPLATE = 'Validator response:\n{description}\n\nFix the errors and try again.'
Member:

Can we make the default be what it currently is?

Suggested change:
- DEFAULT_MODEL_RESPONSE_TEMPLATE = 'Validator response:\n{description}\n\nFix the errors and try again.'
+ DEFAULT_MODEL_RESPONSE_TEMPLATE = '{description}\n\nFix the errors and try again.'
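
For concreteness, a quick sketch of how the two template variants render; the sample description below is invented for illustration:

    # Hypothetical error description; any validator message would do.
    description = 'value must be a positive integer'

    with_prefix = 'Validator response:\n{description}\n\nFix the errors and try again.'
    without_prefix = '{description}\n\nFix the errors and try again.'

    # Renders with the 'Validator response:' header line on top.
    print(with_prefix.format(description=description))
    # Renders the bare description followed by the retry instruction.
    print(without_prefix.format(description=description))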

Contributor Author:

Hi @Kludex,
I believe what @DouweM requested was a "fix" in the sense of changing the string, and this PR is supposed to do that rather than focus on making this more flexible/configurable.
I created a separate issue for that: #2009

Also, after sleeping on it, I noticed that the string was supposed to be added only if tool_name is None, so I need to change this.
And extracting the constant outside of the dataclass is not necessary; I initially did it thinking I would reference it in tests, but then I saw it isn't needed because of the inline-snapshot usage.

Perhaps something simple along these lines, leaving all the flexibility code for the other issue:

        if self.tool_name:
            prefix = ""
        else:
            prefix = "Validator feedback:\n"
        return f'{prefix}{description}\n\nFix the errors and try again.'

What do you guys think?

Contributor:

@hovi For this PR I was expecting to add the prefix only if isinstance(self.content, str) (because the other branch already has the N validation errors: prefix) and self.tool_name is None.

Contributor Author:

Ah! Thanks @DouweM

    def model_response(self) -> str:
        """Return a string message describing why the retry is requested."""
        if isinstance(self.content, str):
            if self.tool_name is None:
                description = f"Validator feedback:\n{self.content}" 
            else:
                description = self.content
        else:
            json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
            description = f'{len(self.content)} validation errors: {json_errors.decode()}'
        return f'{description}\n\nFix the errors and try again.'

Is this what you meant?
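
As a sanity check of that branching, here is a minimal standalone sketch; StubRetryPart is a trimmed stand-in for the real RetryPromptPart, covering only the plain-string content branch:

    from dataclasses import dataclass

    @dataclass
    class StubRetryPart:
        content: str
        tool_name: str | None = None

        def model_response(self) -> str:
            # Prefix only when the retry does not come from a tool call.
            if self.tool_name is None:
                description = f'Validator feedback:\n{self.content}'
            else:
                description = self.content
            return f'{description}\n\nFix the errors and try again.'

    # Output-validator retry (no tool name): the prefix is added.
    print(StubRetryPart(content='too short').model_response())
    # Tool retry: the content passes through without the prefix.
    print(StubRetryPart(content='too short', tool_name='my_tool').model_response())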

Contributor:

@hovi Yep, that's pretty much what I was thinking, in case @Kludex agrees. I'd make it 'Validation feedback' though.

Contributor Author:

@Kludex please take a look at this when you get a minute if you can 🙏

Contributor:

@hovi As in the example you shared above, can we please only include the new 'Validation feedback:' prefix if there's not already the 'N validation errors:' prefix?

I'd also prefer to NOT have an overridable template on RetryPromptPart for now, and consider the question of how to override hard-coded prompt strings separately.



@dataclass(repr=False)
class RetryPromptPart:
"""A message back to a model asking it to try again.
@@ -461,14 +464,18 @@ class RetryPromptPart:
part_kind: Literal['retry-prompt'] = 'retry-prompt'
"""Part type identifier, this is available on all parts as a discriminator."""

+ model_response_template: str = field(
+     default=DEFAULT_MODEL_RESPONSE_TEMPLATE,
+ )

def model_response(self) -> str:
"""Return a string message describing why the retry is requested."""
if isinstance(self.content, str):
description = self.content
else:
json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
description = f'{len(self.content)} validation errors: {json_errors.decode()}'
- return f'{description}\n\nFix the errors and try again.'
+ return self.model_response_template.format(description=description)

def otel_event(self, _settings: InstrumentationSettings) -> Event:
if self.tool_name is None:
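
To see what the hunk above changes in behavior, here is a self-contained sketch; TrimmedRetryPart is a stand-in for the real RetryPromptPart, and plain json replaces the pydantic TypeAdapter serialization:

    import json
    from dataclasses import dataclass, field

    DEFAULT_MODEL_RESPONSE_TEMPLATE = 'Validator response:\n{description}\n\nFix the errors and try again.'

    @dataclass
    class TrimmedRetryPart:
        # Either a plain validator message or a list of error dicts.
        content: str | list[dict]
        model_response_template: str = field(default=DEFAULT_MODEL_RESPONSE_TEMPLATE)

        def model_response(self) -> str:
            if isinstance(self.content, str):
                description = self.content
            else:
                # Simplified: the real code dumps pydantic ErrorDetails.
                json_errors = json.dumps(self.content, indent=2)
                description = f'{len(self.content)} validation errors: {json_errors}'
            return self.model_response_template.format(description=description)

    # Default template prepends the 'Validator response:' header.
    print(TrimmedRetryPart(content='b must be "bar"').model_response())
    # Swapping in another template is the flexibility the reviewers ask to defer.
    print(TrimmedRetryPart(content='b must be "bar"', model_response_template='{description}').model_response())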
4 changes: 4 additions & 0 deletions tests/models/test_instrumented.py
@@ -214,6 +214,7 @@ async def test_instrumented_model(capfire: CaptureLogfire):
{
'body': {
'content': """\
+ Validator response:
retry_prompt1

Fix the errors and try again.\
@@ -238,6 +239,7 @@ async def test_instrumented_model(capfire: CaptureLogfire):
{
'body': {
'content': """\
+ Validator response:
retry_prompt2

Fix the errors and try again.\
@@ -596,6 +598,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
{
'event.name': 'gen_ai.tool.message',
'content': """\
+ Validator response:
retry_prompt1

Fix the errors and try again.\
@@ -609,6 +612,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
{
'event.name': 'gen_ai.user.message',
'content': """\
+ Validator response:
retry_prompt2

Fix the errors and try again.\
4 changes: 2 additions & 2 deletions tests/models/test_model_test.py
@@ -121,7 +121,7 @@ async def my_ret(x: int) -> str:
),
ModelResponse(
parts=[ToolCallPart(tool_name='my_ret', args={'x': 0}, tool_call_id=IsStr())],
- usage=Usage(requests=1, request_tokens=61, response_tokens=8, total_tokens=69),
+ usage=Usage(requests=1, request_tokens=63, response_tokens=8, total_tokens=71),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
),
@@ -134,7 +134,7 @@ async def my_ret(x: int) -> str:
),
ModelResponse(
parts=[TextPart(content='{"my_ret":"1"}')],
- usage=Usage(requests=1, request_tokens=62, response_tokens=12, total_tokens=74),
+ usage=Usage(requests=1, request_tokens=64, response_tokens=12, total_tokens=76),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
),
19 changes: 13 additions & 6 deletions tests/test_agent.py
@@ -113,7 +113,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse
),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
- usage=Usage(requests=1, request_tokens=87, response_tokens=14, total_tokens=101),
+ usage=Usage(requests=1, request_tokens=89, response_tokens=14, total_tokens=103),
model_name='function:return_model:',
timestamp=IsNow(tz=timezone.utc),
),
@@ -172,6 +172,7 @@ def check_b(cls, v: str) -> str:
retry_prompt = user_retry.parts[0]
assert isinstance(retry_prompt, RetryPromptPart)
assert retry_prompt.model_response() == snapshot("""\
+ Validator response:
1 validation errors: [
{
"type": "value_error",
@@ -229,7 +230,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
- usage=Usage(requests=1, request_tokens=63, response_tokens=14, total_tokens=77),
+ usage=Usage(requests=1, request_tokens=64, response_tokens=14, total_tokens=78),
model_name='function:return_model:',
timestamp=IsNow(tz=timezone.utc),
),
@@ -288,7 +289,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
parts=[
ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr())
],
- usage=Usage(requests=1, request_tokens=72, response_tokens=8, total_tokens=80),
+ usage=Usage(requests=1, request_tokens=74, response_tokens=8, total_tokens=82),
model_name='function:return_tuple:',
timestamp=IsNow(tz=timezone.utc),
),
@@ -828,7 +829,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
tool_call_id=IsStr(),
)
],
- usage=Usage(requests=1, request_tokens=68, response_tokens=13, total_tokens=81),
+ usage=Usage(requests=1, request_tokens=70, response_tokens=13, total_tokens=83),
model_name='function:call_tool:',
timestamp=IsDatetime(),
),
@@ -1487,7 +1488,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
),
ModelResponse(
parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
- usage=Usage(requests=1, request_tokens=65, response_tokens=4, total_tokens=69),
+ usage=Usage(requests=1, request_tokens=67, response_tokens=4, total_tokens=71),
model_name='function:empty:',
timestamp=IsNow(tz=timezone.utc),
),
@@ -1527,7 +1528,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
),
ModelResponse(
parts=[TextPart(content='success')],
- usage=Usage(requests=1, request_tokens=65, response_tokens=3, total_tokens=68),
+ usage=Usage(requests=1, request_tokens=67, response_tokens=3, total_tokens=70),
model_name='function:empty:',
timestamp=IsNow(tz=timezone.utc),
),
@@ -2651,6 +2652,12 @@ def foo_tool(foo: Foo) -> int:
'tool_call_id': IsStr(),
'timestamp': IsStr(),
'part_kind': 'retry-prompt',
'model_response_template': """\
Validator response:
{description}

Fix the errors and try again.\
""",
}
],
'instructions': None,