Skip to content

chore(python-sdk): literal defaults #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions python-sdk/ag_ui/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, List, Literal, Optional, Union, Annotated
from pydantic import Field

from .types import Message, State, ConfiguredBaseModel
from .types import Message, State, ConfiguredBaseModel, Role


class EventType(str, Enum):
Expand Down Expand Up @@ -46,36 +46,32 @@ class TextMessageStartEvent(BaseEvent):
"""
Event indicating the start of a text message.
"""
type: Literal[EventType.TEXT_MESSAGE_START]
type: Literal[EventType.TEXT_MESSAGE_START] = Field(EventType.TEXT_MESSAGE_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
message_id: str
role: Literal["assistant"]
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False)


class TextMessageContentEvent(BaseEvent):
"""
Event containing a piece of text message content.
"""
type: Literal[EventType.TEXT_MESSAGE_CONTENT]
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = Field(EventType.TEXT_MESSAGE_CONTENT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
message_id: str
delta: str # This should not be an empty string

def model_post_init(self, __context):
if len(self.delta) == 0:
raise ValueError("Delta must not be an empty string")
delta: str = Field(min_length=1)


class TextMessageEndEvent(BaseEvent):
"""
Event indicating the end of a text message.
"""
type: Literal[EventType.TEXT_MESSAGE_END]
type: Literal[EventType.TEXT_MESSAGE_END] = Field(EventType.TEXT_MESSAGE_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
message_id: str

class TextMessageChunkEvent(BaseEvent):
"""
Event containing a chunk of text message content.
"""
type: Literal[EventType.TEXT_MESSAGE_CHUNK]
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = Field(EventType.TEXT_MESSAGE_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
message_id: Optional[str] = None
role: Optional[Literal["assistant"]] = None
delta: Optional[str] = None
Expand All @@ -84,7 +80,7 @@ class ToolCallStartEvent(BaseEvent):
"""
Event indicating the start of a tool call.
"""
type: Literal[EventType.TOOL_CALL_START]
type: Literal[EventType.TOOL_CALL_START] = Field(EventType.TOOL_CALL_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
tool_call_id: str
tool_call_name: str
parent_message_id: Optional[str] = None
Expand All @@ -94,7 +90,7 @@ class ToolCallArgsEvent(BaseEvent):
"""
Event containing tool call arguments.
"""
type: Literal[EventType.TOOL_CALL_ARGS]
type: Literal[EventType.TOOL_CALL_ARGS] = Field(EventType.TOOL_CALL_ARGS, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
tool_call_id: str
delta: str

Expand All @@ -103,14 +99,14 @@ class ToolCallEndEvent(BaseEvent):
"""
Event indicating the end of a tool call.
"""
type: Literal[EventType.TOOL_CALL_END]
type: Literal[EventType.TOOL_CALL_END] = Field(EventType.TOOL_CALL_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
tool_call_id: str

class ToolCallChunkEvent(BaseEvent):
"""
Event containing a chunk of tool call content.
"""
type: Literal[EventType.TOOL_CALL_CHUNK]
type: Literal[EventType.TOOL_CALL_CHUNK] = Field(EventType.TOOL_CALL_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
tool_call_id: Optional[str] = None
tool_call_name: Optional[str] = None
parent_message_id: Optional[str] = None
Expand All @@ -120,31 +116,31 @@ class StateSnapshotEvent(BaseEvent):
"""
Event containing a snapshot of the state.
"""
type: Literal[EventType.STATE_SNAPSHOT]
type: Literal[EventType.STATE_SNAPSHOT] = Field(EventType.STATE_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
snapshot: State


class StateDeltaEvent(BaseEvent):
"""
Event containing a delta of the state.
"""
type: Literal[EventType.STATE_DELTA]
type: Literal[EventType.STATE_DELTA] = Field(EventType.STATE_DELTA, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
delta: List[Any] # JSON Patch (RFC 6902)


class MessagesSnapshotEvent(BaseEvent):
"""
Event containing a snapshot of the messages.
"""
type: Literal[EventType.MESSAGES_SNAPSHOT]
type: Literal[EventType.MESSAGES_SNAPSHOT] = Field(EventType.MESSAGES_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
messages: List[Message]


class RawEvent(BaseEvent):
"""
Event containing a raw event.
"""
type: Literal[EventType.RAW]
type: Literal[EventType.RAW] = Field(EventType.RAW, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
event: Any
source: Optional[str] = None

Expand All @@ -153,7 +149,7 @@ class CustomEvent(BaseEvent):
"""
Event containing a custom event.
"""
type: Literal[EventType.CUSTOM]
type: Literal[EventType.CUSTOM] = Field(EventType.CUSTOM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
name: str
value: Any

Expand All @@ -162,7 +158,7 @@ class RunStartedEvent(BaseEvent):
"""
Event indicating that a run has started.
"""
type: Literal[EventType.RUN_STARTED]
type: Literal[EventType.RUN_STARTED] = Field(EventType.RUN_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
thread_id: str
run_id: str

Expand All @@ -171,7 +167,7 @@ class RunFinishedEvent(BaseEvent):
"""
Event indicating that a run has finished.
"""
type: Literal[EventType.RUN_FINISHED]
type: Literal[EventType.RUN_FINISHED] = Field(EventType.RUN_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
thread_id: str
run_id: str

Expand All @@ -180,7 +176,7 @@ class RunErrorEvent(BaseEvent):
"""
Event indicating that a run has encountered an error.
"""
type: Literal[EventType.RUN_ERROR]
type: Literal[EventType.RUN_ERROR] = Field(EventType.RUN_ERROR, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
message: str
code: Optional[str] = None

Expand All @@ -189,15 +185,15 @@ class StepStartedEvent(BaseEvent):
"""
Event indicating that a step has started.
"""
type: Literal[EventType.STEP_STARTED]
type: Literal[EventType.STEP_STARTED] = Field(EventType.STEP_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
step_name: str


class StepFinishedEvent(BaseEvent):
"""
Event indicating that a step has finished.
"""
type: Literal[EventType.STEP_FINISHED]
type: Literal[EventType.STEP_FINISHED] = Field(EventType.STEP_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
step_name: str


Expand Down
29 changes: 19 additions & 10 deletions python-sdk/ag_ui/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@
This module contains the types for the Agent User Interaction Protocol Python SDK.
"""

from enum import Enum
from typing import Any, List, Literal, Optional, Union, Annotated
from pydantic import BaseModel, Field, ConfigDict
from pydantic.alias_generators import to_camel


class Role(str, Enum):
"""
The role of an actor.
"""
DEVELOPER = "developer"
SYSTEM = "system"
ASSISTANT = "assistant"
USER = "user"
TOOL = "tool"


class ConfiguredBaseModel(BaseModel):
"""
A configurable base model.
Expand All @@ -31,7 +44,7 @@ class ToolCall(ConfiguredBaseModel):
A tool call, modelled after OpenAI tool calls.
"""
id: str
type: Literal["function"]
type: Literal["function"] = Field("function", init=False) # pyright: ignore[reportIncompatibleVariableOverride]
function: FunctionCall


Expand All @@ -49,31 +62,30 @@ class DeveloperMessage(BaseMessage):
"""
A developer message.
"""
role: Literal["developer"]
role: Literal[Role.DEVELOPER] = Field(Role.DEVELOPER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
content: str


class SystemMessage(BaseMessage):
"""
A system message.
"""
role: Literal["system"]
role: Literal[Role.SYSTEM] = Field(Role.SYSTEM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
content: str


class AssistantMessage(BaseMessage):
"""
An assistant message.
"""
role: Literal["assistant"]
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
tool_calls: Optional[List[ToolCall]] = None


class UserMessage(BaseMessage):
"""
A user message.
"""
role: Literal["user"]
role: Literal[Role.USER] = Field(Role.USER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
content: str


Expand All @@ -82,7 +94,7 @@ class ToolMessage(ConfiguredBaseModel):
A tool result message.
"""
id: str
role: Literal["tool"]
role: Literal[Role.TOOL] = Field(Role.TOOL, init=False)
content: str
tool_call_id: str

Expand All @@ -93,9 +105,6 @@ class ToolMessage(ConfiguredBaseModel):
]


Role = Literal["developer", "system", "assistant", "user", "tool"]


class Context(ConfiguredBaseModel):
"""
Additional context for the agent.
Expand Down
4 changes: 0 additions & 4 deletions python-sdk/tests/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def test_encode_sse_method(self):
"""Test the encode_sse method"""
# Create a test event with specific data
event = TextMessageContentEvent(
type=EventType.TEXT_MESSAGE_CONTENT,
message_id="msg_123",
delta="Hello, world!",
timestamp=1648214400000
Expand Down Expand Up @@ -83,7 +82,6 @@ def test_encode_with_different_event_types(self):

# Test with a more complex event
content_event = TextMessageContentEvent(
type=EventType.TEXT_MESSAGE_CONTENT,
message_id="msg_456",
delta="Testing different events",
timestamp=1648214400000
Expand Down Expand Up @@ -130,7 +128,6 @@ def test_null_value_exclusion(self):
# Test with another event that has optional fields
# Create event with some optional fields set to None
event_with_optional = ToolCallStartEvent(
type=EventType.TOOL_CALL_START,
tool_call_id="call_123",
tool_call_name="test_tool",
parent_message_id=None, # Optional field explicitly set to None
Expand All @@ -152,7 +149,6 @@ def test_round_trip_serialization(self):
"""Test that events can be serialized to JSON with camelCase and deserialized back correctly"""
# Create a complex event with multiple fields
original_event = ToolCallStartEvent(
type=EventType.TOOL_CALL_START,
tool_call_id="call_abc123",
tool_call_name="search_tool",
parent_message_id="msg_parent_456",
Expand Down
Loading