Skip to content

Commit 097f6fd

Browse files
committed
chore(python-sdk): literal defaults
Set the event type and message role using pydantic Field and disable setting in the constructor to help prevent accidental misuse while reducing unnecessary boilerplate code. Replace Role type alias with an Enum to eliminate repeating string literals and align with EventType. Fixes #41
1 parent 98dbd1a commit 097f6fd

File tree

5 files changed

+43
-94
lines changed

5 files changed

+43
-94
lines changed

python-sdk/ag_ui/core/events.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, List, Literal, Optional, Union, Annotated
77
from pydantic import Field
88

9-
from .types import Message, State, ConfiguredBaseModel
9+
from .types import Message, State, ConfiguredBaseModel, Role
1010

1111

1212
class EventType(str, Enum):
@@ -46,36 +46,32 @@ class TextMessageStartEvent(BaseEvent):
4646
"""
4747
Event indicating the start of a text message.
4848
"""
49-
type: Literal[EventType.TEXT_MESSAGE_START]
49+
type: Literal[EventType.TEXT_MESSAGE_START] = Field(EventType.TEXT_MESSAGE_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
5050
message_id: str
51-
role: Literal["assistant"]
51+
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False)
5252

5353

5454
class TextMessageContentEvent(BaseEvent):
5555
"""
5656
Event containing a piece of text message content.
5757
"""
58-
type: Literal[EventType.TEXT_MESSAGE_CONTENT]
58+
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = Field(EventType.TEXT_MESSAGE_CONTENT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
5959
message_id: str
60-
delta: str # This should not be an empty string
61-
62-
def model_post_init(self, __context):
63-
if len(self.delta) == 0:
64-
raise ValueError("Delta must not be an empty string")
60+
delta: str = Field(min_length=1) # This should not be an empty string
6561

6662

6763
class TextMessageEndEvent(BaseEvent):
6864
"""
6965
Event indicating the end of a text message.
7066
"""
71-
type: Literal[EventType.TEXT_MESSAGE_END]
67+
type: Literal[EventType.TEXT_MESSAGE_END] = Field(EventType.TEXT_MESSAGE_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
7268
message_id: str
7369

7470
class TextMessageChunkEvent(BaseEvent):
7571
"""
7672
Event containing a chunk of text message content.
7773
"""
78-
type: Literal[EventType.TEXT_MESSAGE_CHUNK]
74+
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = Field(EventType.TEXT_MESSAGE_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
7975
message_id: Optional[str] = None
8076
role: Optional[Literal["assistant"]] = None
8177
delta: Optional[str] = None
@@ -84,7 +80,7 @@ class ToolCallStartEvent(BaseEvent):
8480
"""
8581
Event indicating the start of a tool call.
8682
"""
87-
type: Literal[EventType.TOOL_CALL_START]
83+
type: Literal[EventType.TOOL_CALL_START] = Field(EventType.TOOL_CALL_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
8884
tool_call_id: str
8985
tool_call_name: str
9086
parent_message_id: Optional[str] = None
@@ -94,7 +90,7 @@ class ToolCallArgsEvent(BaseEvent):
9490
"""
9591
Event containing tool call arguments.
9692
"""
97-
type: Literal[EventType.TOOL_CALL_ARGS]
93+
type: Literal[EventType.TOOL_CALL_ARGS] = Field(EventType.TOOL_CALL_ARGS, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
9894
tool_call_id: str
9995
delta: str
10096

@@ -103,14 +99,14 @@ class ToolCallEndEvent(BaseEvent):
10399
"""
104100
Event indicating the end of a tool call.
105101
"""
106-
type: Literal[EventType.TOOL_CALL_END]
102+
type: Literal[EventType.TOOL_CALL_END] = Field(EventType.TOOL_CALL_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
107103
tool_call_id: str
108104

109105
class ToolCallChunkEvent(BaseEvent):
110106
"""
111107
Event containing a chunk of tool call content.
112108
"""
113-
type: Literal[EventType.TOOL_CALL_CHUNK]
109+
type: Literal[EventType.TOOL_CALL_CHUNK] = Field(EventType.TOOL_CALL_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
114110
tool_call_id: Optional[str] = None
115111
tool_call_name: Optional[str] = None
116112
parent_message_id: Optional[str] = None
@@ -120,31 +116,31 @@ class StateSnapshotEvent(BaseEvent):
120116
"""
121117
Event containing a snapshot of the state.
122118
"""
123-
type: Literal[EventType.STATE_SNAPSHOT]
119+
type: Literal[EventType.STATE_SNAPSHOT] = Field(EventType.STATE_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
124120
snapshot: State
125121

126122

127123
class StateDeltaEvent(BaseEvent):
128124
"""
129125
Event containing a delta of the state.
130126
"""
131-
type: Literal[EventType.STATE_DELTA]
127+
type: Literal[EventType.STATE_DELTA] = Field(EventType.STATE_DELTA, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
132128
delta: List[Any] # JSON Patch (RFC 6902)
133129

134130

135131
class MessagesSnapshotEvent(BaseEvent):
136132
"""
137133
Event containing a snapshot of the messages.
138134
"""
139-
type: Literal[EventType.MESSAGES_SNAPSHOT]
135+
type: Literal[EventType.MESSAGES_SNAPSHOT] = Field(EventType.MESSAGES_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
140136
messages: List[Message]
141137

142138

143139
class RawEvent(BaseEvent):
144140
"""
145141
Event containing a raw event.
146142
"""
147-
type: Literal[EventType.RAW]
143+
type: Literal[EventType.RAW] = Field(EventType.RAW, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
148144
event: Any
149145
source: Optional[str] = None
150146

@@ -153,7 +149,7 @@ class CustomEvent(BaseEvent):
153149
"""
154150
Event containing a custom event.
155151
"""
156-
type: Literal[EventType.CUSTOM]
152+
type: Literal[EventType.CUSTOM] = Field(EventType.CUSTOM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
157153
name: str
158154
value: Any
159155

@@ -162,7 +158,7 @@ class RunStartedEvent(BaseEvent):
162158
"""
163159
Event indicating that a run has started.
164160
"""
165-
type: Literal[EventType.RUN_STARTED]
161+
type: Literal[EventType.RUN_STARTED] = Field(EventType.RUN_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
166162
thread_id: str
167163
run_id: str
168164

@@ -171,7 +167,7 @@ class RunFinishedEvent(BaseEvent):
171167
"""
172168
Event indicating that a run has finished.
173169
"""
174-
type: Literal[EventType.RUN_FINISHED]
170+
type: Literal[EventType.RUN_FINISHED] = Field(EventType.RUN_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
175171
thread_id: str
176172
run_id: str
177173

@@ -180,7 +176,7 @@ class RunErrorEvent(BaseEvent):
180176
"""
181177
Event indicating that a run has encountered an error.
182178
"""
183-
type: Literal[EventType.RUN_ERROR]
179+
type: Literal[EventType.RUN_ERROR] = Field(EventType.RUN_ERROR, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
184180
message: str
185181
code: Optional[str] = None
186182

@@ -189,15 +185,15 @@ class StepStartedEvent(BaseEvent):
189185
"""
190186
Event indicating that a step has started.
191187
"""
192-
type: Literal[EventType.STEP_STARTED]
188+
type: Literal[EventType.STEP_STARTED] = Field(EventType.STEP_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
193189
step_name: str
194190

195191

196192
class StepFinishedEvent(BaseEvent):
197193
"""
198194
Event indicating that a step has finished.
199195
"""
200-
type: Literal[EventType.STEP_FINISHED]
196+
type: Literal[EventType.STEP_FINISHED] = Field(EventType.STEP_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
201197
step_name: str
202198

203199

python-sdk/ag_ui/core/types.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,23 @@
22
This module contains the types for the Agent User Interaction Protocol Python SDK.
33
"""
44

5+
from enum import Enum
56
from typing import Any, List, Literal, Optional, Union, Annotated
67
from pydantic import BaseModel, Field, ConfigDict
78
from pydantic.alias_generators import to_camel
89

10+
11+
class Role(str, Enum):
12+
"""
13+
The role of an actor.
14+
"""
15+
DEVELOPER = "developer"
16+
SYSTEM = "system"
17+
ASSISTANT = "assistant"
18+
USER = "user"
19+
TOOL = "tool"
20+
21+
922
class ConfiguredBaseModel(BaseModel):
1023
"""
1124
A configurable base model.
@@ -31,7 +44,7 @@ class ToolCall(ConfiguredBaseModel):
3144
A tool call, modelled after OpenAI tool calls.
3245
"""
3346
id: str
34-
type: Literal["function"]
47+
type: Literal["function"] = Field("function", init=False) # pyright: ignore[reportIncompatibleVariableOverride]
3548
function: FunctionCall
3649

3750

@@ -49,31 +62,30 @@ class DeveloperMessage(BaseMessage):
4962
"""
5063
A developer message.
5164
"""
52-
role: Literal["developer"]
65+
role: Literal[Role.DEVELOPER] = Field(Role.DEVELOPER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
5366
content: str
5467

5568

5669
class SystemMessage(BaseMessage):
5770
"""
5871
A system message.
5972
"""
60-
role: Literal["system"]
73+
role: Literal[Role.SYSTEM] = Field(Role.SYSTEM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
6174
content: str
6275

6376

6477
class AssistantMessage(BaseMessage):
6578
"""
6679
An assistant message.
6780
"""
68-
role: Literal["assistant"]
81+
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
6982
tool_calls: Optional[List[ToolCall]] = None
7083

71-
7284
class UserMessage(BaseMessage):
7385
"""
7486
A user message.
7587
"""
76-
role: Literal["user"]
88+
role: Literal[Role.USER] = Field(Role.USER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
7789
content: str
7890

7991

@@ -82,7 +94,7 @@ class ToolMessage(ConfiguredBaseModel):
8294
A tool result message.
8395
"""
8496
id: str
85-
role: Literal["tool"]
97+
role: Literal[Role.TOOL] = Field(Role.TOOL, init=False)
8698
content: str
8799
tool_call_id: str
88100

@@ -93,9 +105,6 @@ class ToolMessage(ConfiguredBaseModel):
93105
]
94106

95107

96-
Role = Literal["developer", "system", "assistant", "user", "tool"]
97-
98-
99108
class Context(ConfiguredBaseModel):
100109
"""
101110
Additional context for the agent.

python-sdk/tests/test_encoder.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def test_encode_sse_method(self):
4343
"""Test the encode_sse method"""
4444
# Create a test event with specific data
4545
event = TextMessageContentEvent(
46-
type=EventType.TEXT_MESSAGE_CONTENT,
4746
message_id="msg_123",
4847
delta="Hello, world!",
4948
timestamp=1648214400000
@@ -83,7 +82,6 @@ def test_encode_with_different_event_types(self):
8382

8483
# Test with a more complex event
8584
content_event = TextMessageContentEvent(
86-
type=EventType.TEXT_MESSAGE_CONTENT,
8785
message_id="msg_456",
8886
delta="Testing different events",
8987
timestamp=1648214400000
@@ -130,7 +128,6 @@ def test_null_value_exclusion(self):
130128
# Test with another event that has optional fields
131129
# Create event with some optional fields set to None
132130
event_with_optional = ToolCallStartEvent(
133-
type=EventType.TOOL_CALL_START,
134131
tool_call_id="call_123",
135132
tool_call_name="test_tool",
136133
parent_message_id=None, # Optional field explicitly set to None
@@ -152,7 +149,6 @@ def test_round_trip_serialization(self):
152149
"""Test that events can be serialized to JSON with camelCase and deserialized back correctly"""
153150
# Create a complex event with multiple fields
154151
original_event = ToolCallStartEvent(
155-
type=EventType.TOOL_CALL_START,
156152
tool_call_id="call_abc123",
157153
tool_call_name="search_tool",
158154
parent_message_id="msg_parent_456",

0 commit comments

Comments
 (0)