Skip to content

Commit 131d793

Browse files
committed
refactor(ag-ui): push encode to the top level
Push encode to the top level, eliminating the need to pass the encoder to lower levels which simplifies the code and makes it more maintainable.
1 parent 8d28862 commit 131d793

File tree

1 file changed

+43
-62
lines changed

1 file changed

+43
-62
lines changed

pydantic_ai_ag_ui/pydantic_ai_ag_ui/adapter.py

Lines changed: 43 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class _RequestStreamContext:
8585

8686
message_id: str = ''
8787
last_tool_call_id: str | None = None
88-
part_ends: list[str | None] = field(default_factory=lambda: list[str | None]())
88+
part_ends: list[BaseEvent | None] = field(default_factory=lambda: list[BaseEvent | None]())
8989
local_tool_calls: set[str] = field(default_factory=set)
9090

9191
def new_message_id(self) -> str:
@@ -257,13 +257,13 @@ async def run(
257257
infer_name=infer_name,
258258
additional_tools=run_tools,
259259
) as run:
260-
async for event in self._agent_stream(encoder, tool_names, run):
260+
async for event in self._agent_stream(tool_names, run):
261261
if event is None:
262262
# Tool call signals early return, so we stop processing.
263263
self.logger.debug('tool call early return')
264264
break
265265

266-
yield event
266+
yield encoder.encode(event)
267267
except RunError as e:
268268
self.logger.exception('agent run')
269269
yield encoder.encode(
@@ -285,9 +285,7 @@ async def run(
285285

286286
self.logger.info('done thread_id=%s run_id=%s', run_input.thread_id, run_input.run_id)
287287

288-
async def _tool_events(
289-
self, encoder: EventEncoder, parts: list[ModelRequestPart]
290-
) -> AsyncGenerator[str | None, None]:
288+
async def _tool_events(self, parts: list[ModelRequestPart]) -> AsyncGenerator[BaseEvent | None, None]:
291289
"""Check for tool call results that are AG-UI events.
292290
293291
Args:
@@ -309,15 +307,15 @@ async def _tool_events(
309307
match part.content:
310308
case BaseEvent():
311309
self.logger.debug('ag-ui event: %s', part.content)
312-
yield encoder.encode(part.content)
310+
yield part.content
313311
case str() | bytes():
314312
# Avoid strings and bytes being checked as iterable.
315313
pass
316314
case Iterable() as iter:
317315
for item in iter:
318316
if isinstance(item, BaseEvent): # pragma: no branch
319317
self.logger.debug('ag-ui event: %s', item)
320-
yield encoder.encode(item)
318+
yield item
321319
case _: # pragma: no cover
322320
# Not currently interested in other types.
323321
pass
@@ -371,51 +369,48 @@ def _tool_stub(*args: Any, **kwargs: Any) -> ToolResult:
371369

372370
async def _agent_stream(
373371
self,
374-
encoder: EventEncoder,
375372
tool_names: dict[str, str],
376373
run: AgentRun[AgentDepsT, Any],
377-
) -> AsyncGenerator[str | None, None]:
374+
) -> AsyncGenerator[BaseEvent | None, None]:
378375
"""Run the agent streaming responses using AG-UI protocol events.
379376
380377
Args:
381-
encoder: The event encoder to use for encoding events.
382378
tool_names: A mapping of tool names to their AG-UI names.
383379
run: The agent run to process.
384380
385381
Yields:
386382
AG-UI Server-Sent Events (SSE).
387383
"""
388384
node: AgentNode[AgentDepsT, Any] | End[FinalResult[Any]]
389-
msg: str | None
385+
msg: BaseEvent | None
390386
async for node in run:
391387
self.logger.debug('processing node=%r', node)
392388
if not isinstance(node, ModelRequestNode):
393389
# Not interested UserPromptNode, CallToolsNode or End.
394390
continue
395391

396392
# Check for state updates.
397-
snapshot: str | None
398-
async for snapshot in self._tool_events(encoder, node.request.parts):
393+
snapshot: BaseEvent | None
394+
async for snapshot in self._tool_events(node.request.parts):
399395
yield snapshot
400396

401397
stream_ctx: _RequestStreamContext = _RequestStreamContext()
402398
request_stream: AgentStream[AgentDepsT]
403399
async with node.stream(run.ctx) as request_stream:
404400
agent_event: AgentStreamEvent
405401
async for agent_event in request_stream:
406-
async for msg in self._handle_agent_event(encoder, tool_names, stream_ctx, agent_event):
402+
async for msg in self._handle_agent_event(tool_names, stream_ctx, agent_event):
407403
yield msg
408404

409405
for part_end in stream_ctx.part_ends:
410406
yield part_end
411407

412408
async def _handle_agent_event(
413409
self,
414-
encoder: EventEncoder,
415410
tool_names: dict[str, str],
416411
stream_ctx: _RequestStreamContext,
417412
agent_event: AgentStreamEvent,
418-
) -> AsyncGenerator[str | None, None]:
413+
) -> AsyncGenerator[BaseEvent | None, None]:
419414
"""Handle an agent event and yield AG-UI protocol events.
420415
421416
Args:
@@ -431,36 +426,30 @@ async def _handle_agent_event(
431426
match agent_event:
432427
case PartStartEvent():
433428
# If we have a previous part end it.
434-
part_end: str | None
429+
part_end: BaseEvent | None
435430
for part_end in stream_ctx.part_ends:
436431
yield part_end
437432
stream_ctx.part_ends.clear()
438433

439434
match agent_event.part:
440435
case TextPart():
441436
message_id: str = stream_ctx.new_message_id()
442-
yield encoder.encode(
443-
TextMessageStartEvent(
444-
type=EventType.TEXT_MESSAGE_START,
445-
message_id=message_id,
446-
role=Role.ASSISTANT.value,
447-
),
437+
yield TextMessageStartEvent(
438+
type=EventType.TEXT_MESSAGE_START,
439+
message_id=message_id,
440+
role=Role.ASSISTANT.value,
448441
)
449442
stream_ctx.part_ends = [
450-
encoder.encode(
451-
TextMessageEndEvent(
452-
type=EventType.TEXT_MESSAGE_END,
453-
message_id=message_id,
454-
),
443+
TextMessageEndEvent(
444+
type=EventType.TEXT_MESSAGE_END,
445+
message_id=message_id,
455446
),
456447
]
457448
if agent_event.part.content:
458-
yield encoder.encode( # pragma: no cover
459-
TextMessageContentEvent(
460-
type=EventType.TEXT_MESSAGE_CONTENT,
461-
message_id=message_id,
462-
delta=agent_event.part.content,
463-
),
449+
yield TextMessageContentEvent( # pragma: no cover
450+
type=EventType.TEXT_MESSAGE_CONTENT,
451+
message_id=message_id,
452+
delta=agent_event.part.content,
464453
)
465454
case ToolCallPart(): # pragma: no branch
466455
tool_name: str | None = tool_names.get(agent_event.part.tool_name)
@@ -469,19 +458,15 @@ async def _handle_agent_event(
469458
return
470459

471460
stream_ctx.last_tool_call_id = agent_event.part.tool_call_id
472-
yield encoder.encode(
473-
ToolCallStartEvent(
474-
type=EventType.TOOL_CALL_START,
475-
tool_call_id=agent_event.part.tool_call_id,
476-
tool_call_name=tool_name or agent_event.part.tool_name,
477-
),
461+
yield ToolCallStartEvent(
462+
type=EventType.TOOL_CALL_START,
463+
tool_call_id=agent_event.part.tool_call_id,
464+
tool_call_name=tool_name or agent_event.part.tool_name,
478465
)
479466
stream_ctx.part_ends = [
480-
encoder.encode(
481-
ToolCallEndEvent(
482-
type=EventType.TOOL_CALL_END,
483-
tool_call_id=agent_event.part.tool_call_id,
484-
),
467+
ToolCallEndEvent(
468+
type=EventType.TOOL_CALL_END,
469+
tool_call_id=agent_event.part.tool_call_id,
485470
),
486471
None, # Signal continuation of the stream.
487472
]
@@ -491,28 +476,24 @@ async def _handle_agent_event(
491476
case PartDeltaEvent():
492477
match agent_event.delta:
493478
case TextPartDelta():
494-
yield encoder.encode(
495-
TextMessageContentEvent(
496-
type=EventType.TEXT_MESSAGE_CONTENT,
497-
message_id=stream_ctx.message_id,
498-
delta=agent_event.delta.content_delta,
499-
),
479+
yield TextMessageContentEvent(
480+
type=EventType.TEXT_MESSAGE_CONTENT,
481+
message_id=stream_ctx.message_id,
482+
delta=agent_event.delta.content_delta,
500483
)
501484
case ToolCallPartDelta(): # pragma: no branch
502485
if agent_event.delta.tool_call_id in stream_ctx.local_tool_calls:
503486
# Local tool calls are not sent to the UI.
504487
return
505488

506-
yield encoder.encode(
507-
ToolCallArgsEvent(
508-
type=EventType.TOOL_CALL_ARGS,
509-
tool_call_id=agent_event.delta.tool_call_id
510-
or stream_ctx.last_tool_call_id
511-
or 'unknown', # Should never be unknown, but just in case.
512-
delta=agent_event.delta.args_delta
513-
if isinstance(agent_event.delta.args_delta, str)
514-
else json.dumps(agent_event.delta.args_delta),
515-
),
489+
yield ToolCallArgsEvent(
490+
type=EventType.TOOL_CALL_ARGS,
491+
tool_call_id=agent_event.delta.tool_call_id
492+
or stream_ctx.last_tool_call_id
493+
or 'unknown', # Should never be unknown, but just in case.
494+
delta=agent_event.delta.args_delta
495+
if isinstance(agent_event.delta.args_delta, str)
496+
else json.dumps(agent_event.delta.args_delta),
516497
)
517498
case ThinkingPartDelta(): # pragma: no branch
518499
# No equivalent AG-UI event yet.

0 commit comments

Comments
 (0)