Skip to content

Commit 925d2e0

Browse files
stevenhmaxkorp
authored andcommitted
chore(python-sdk): literal defaults
Set the event type and message role using pydantic Field and disable setting in the constructor to help prevent accidental misuse while reducing unnecessary boilerplate code. Replace Role type alias with an Enum to eliminate repeating string literals and align with EventType. Use Field(min_length=1) to validate TextMessageContentEvent delta, simplifying the code. Fixes #41
1 parent 5801458 commit 925d2e0

File tree

5 files changed

+43
-94
lines changed

5 files changed

+43
-94
lines changed

python-sdk/ag_ui/core/events.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, List, Literal, Optional, Union, Annotated
77
from pydantic import Field
88

9-
from .types import Message, State, ConfiguredBaseModel
9+
from .types import Message, State, ConfiguredBaseModel, Role
1010

1111

1212
class EventType(str, Enum):
@@ -52,36 +52,32 @@ class TextMessageStartEvent(BaseEvent):
5252
"""
5353
Event indicating the start of a text message.
5454
"""
55-
type: Literal[EventType.TEXT_MESSAGE_START]
55+
type: Literal[EventType.TEXT_MESSAGE_START] = Field(EventType.TEXT_MESSAGE_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
5656
message_id: str
57-
role: Literal["assistant"]
57+
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False)
5858

5959

6060
class TextMessageContentEvent(BaseEvent):
6161
"""
6262
Event containing a piece of text message content.
6363
"""
64-
type: Literal[EventType.TEXT_MESSAGE_CONTENT]
64+
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = Field(EventType.TEXT_MESSAGE_CONTENT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
6565
message_id: str
66-
delta: str # This should not be an empty string
67-
68-
def model_post_init(self, __context):
69-
if len(self.delta) == 0:
70-
raise ValueError("Delta must not be an empty string")
66+
delta: str = Field(min_length=1)
7167

7268

7369
class TextMessageEndEvent(BaseEvent):
7470
"""
7571
Event indicating the end of a text message.
7672
"""
77-
type: Literal[EventType.TEXT_MESSAGE_END]
73+
type: Literal[EventType.TEXT_MESSAGE_END] = Field(EventType.TEXT_MESSAGE_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
7874
message_id: str
7975

8076
class TextMessageChunkEvent(BaseEvent):
8177
"""
8278
Event containing a chunk of text message content.
8379
"""
84-
type: Literal[EventType.TEXT_MESSAGE_CHUNK]
80+
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = Field(EventType.TEXT_MESSAGE_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
8581
message_id: Optional[str] = None
8682
role: Optional[Literal["assistant"]] = None
8783
delta: Optional[str] = None
@@ -113,7 +109,7 @@ class ToolCallStartEvent(BaseEvent):
113109
"""
114110
Event indicating the start of a tool call.
115111
"""
116-
type: Literal[EventType.TOOL_CALL_START]
112+
type: Literal[EventType.TOOL_CALL_START] = Field(EventType.TOOL_CALL_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
117113
tool_call_id: str
118114
tool_call_name: str
119115
parent_message_id: Optional[str] = None
@@ -123,7 +119,7 @@ class ToolCallArgsEvent(BaseEvent):
123119
"""
124120
Event containing tool call arguments.
125121
"""
126-
type: Literal[EventType.TOOL_CALL_ARGS]
122+
type: Literal[EventType.TOOL_CALL_ARGS] = Field(EventType.TOOL_CALL_ARGS, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
127123
tool_call_id: str
128124
delta: str
129125

@@ -132,14 +128,14 @@ class ToolCallEndEvent(BaseEvent):
132128
"""
133129
Event indicating the end of a tool call.
134130
"""
135-
type: Literal[EventType.TOOL_CALL_END]
131+
type: Literal[EventType.TOOL_CALL_END] = Field(EventType.TOOL_CALL_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
136132
tool_call_id: str
137133

138134
class ToolCallChunkEvent(BaseEvent):
139135
"""
140136
Event containing a chunk of tool call content.
141137
"""
142-
type: Literal[EventType.TOOL_CALL_CHUNK]
138+
type: Literal[EventType.TOOL_CALL_CHUNK] = Field(EventType.TOOL_CALL_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
143139
tool_call_id: Optional[str] = None
144140
tool_call_name: Optional[str] = None
145141
parent_message_id: Optional[str] = None
@@ -172,31 +168,31 @@ class StateSnapshotEvent(BaseEvent):
172168
"""
173169
Event containing a snapshot of the state.
174170
"""
175-
type: Literal[EventType.STATE_SNAPSHOT]
171+
type: Literal[EventType.STATE_SNAPSHOT] = Field(EventType.STATE_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
176172
snapshot: State
177173

178174

179175
class StateDeltaEvent(BaseEvent):
180176
"""
181177
Event containing a delta of the state.
182178
"""
183-
type: Literal[EventType.STATE_DELTA]
179+
type: Literal[EventType.STATE_DELTA] = Field(EventType.STATE_DELTA, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
184180
delta: List[Any] # JSON Patch (RFC 6902)
185181

186182

187183
class MessagesSnapshotEvent(BaseEvent):
188184
"""
189185
Event containing a snapshot of the messages.
190186
"""
191-
type: Literal[EventType.MESSAGES_SNAPSHOT]
187+
type: Literal[EventType.MESSAGES_SNAPSHOT] = Field(EventType.MESSAGES_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
192188
messages: List[Message]
193189

194190

195191
class RawEvent(BaseEvent):
196192
"""
197193
Event containing a raw event.
198194
"""
199-
type: Literal[EventType.RAW]
195+
type: Literal[EventType.RAW] = Field(EventType.RAW, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
200196
event: Any
201197
source: Optional[str] = None
202198

@@ -205,7 +201,7 @@ class CustomEvent(BaseEvent):
205201
"""
206202
Event containing a custom event.
207203
"""
208-
type: Literal[EventType.CUSTOM]
204+
type: Literal[EventType.CUSTOM] = Field(EventType.CUSTOM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
209205
name: str
210206
value: Any
211207

@@ -214,7 +210,7 @@ class RunStartedEvent(BaseEvent):
214210
"""
215211
Event indicating that a run has started.
216212
"""
217-
type: Literal[EventType.RUN_STARTED]
213+
type: Literal[EventType.RUN_STARTED] = Field(EventType.RUN_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
218214
thread_id: str
219215
run_id: str
220216

@@ -223,7 +219,7 @@ class RunFinishedEvent(BaseEvent):
223219
"""
224220
Event indicating that a run has finished.
225221
"""
226-
type: Literal[EventType.RUN_FINISHED]
222+
type: Literal[EventType.RUN_FINISHED] = Field(EventType.RUN_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
227223
thread_id: str
228224
run_id: str
229225
result: Optional[Any] = None
@@ -233,7 +229,7 @@ class RunErrorEvent(BaseEvent):
233229
"""
234230
Event indicating that a run has encountered an error.
235231
"""
236-
type: Literal[EventType.RUN_ERROR]
232+
type: Literal[EventType.RUN_ERROR] = Field(EventType.RUN_ERROR, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
237233
message: str
238234
code: Optional[str] = None
239235

@@ -242,15 +238,15 @@ class StepStartedEvent(BaseEvent):
242238
"""
243239
Event indicating that a step has started.
244240
"""
245-
type: Literal[EventType.STEP_STARTED]
241+
type: Literal[EventType.STEP_STARTED] = Field(EventType.STEP_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
246242
step_name: str
247243

248244

249245
class StepFinishedEvent(BaseEvent):
250246
"""
251247
Event indicating that a step has finished.
252248
"""
253-
type: Literal[EventType.STEP_FINISHED]
249+
type: Literal[EventType.STEP_FINISHED] = Field(EventType.STEP_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
254250
step_name: str
255251

256252

python-sdk/ag_ui/core/types.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,23 @@
22
This module contains the types for the Agent User Interaction Protocol Python SDK.
33
"""
44

5+
from enum import Enum
56
from typing import Any, List, Literal, Optional, Union, Annotated
67
from pydantic import BaseModel, Field, ConfigDict
78
from pydantic.alias_generators import to_camel
89

10+
11+
class Role(str, Enum):
12+
"""
13+
The role of an actor.
14+
"""
15+
DEVELOPER = "developer"
16+
SYSTEM = "system"
17+
ASSISTANT = "assistant"
18+
USER = "user"
19+
TOOL = "tool"
20+
21+
922
class ConfiguredBaseModel(BaseModel):
1023
"""
1124
A configurable base model.
@@ -30,7 +43,7 @@ class ToolCall(ConfiguredBaseModel):
3043
A tool call, modelled after OpenAI tool calls.
3144
"""
3245
id: str
33-
type: Literal["function"]
46+
type: Literal["function"] = Field("function", init=False) # pyright: ignore[reportIncompatibleVariableOverride]
3447
function: FunctionCall
3548

3649

@@ -48,31 +61,30 @@ class DeveloperMessage(BaseMessage):
4861
"""
4962
A developer message.
5063
"""
51-
role: Literal["developer"]
64+
role: Literal[Role.DEVELOPER] = Field(Role.DEVELOPER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
5265
content: str
5366

5467

5568
class SystemMessage(BaseMessage):
5669
"""
5770
A system message.
5871
"""
59-
role: Literal["system"]
72+
role: Literal[Role.SYSTEM] = Field(Role.SYSTEM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
6073
content: str
6174

6275

6376
class AssistantMessage(BaseMessage):
6477
"""
6578
An assistant message.
6679
"""
67-
role: Literal["assistant"]
80+
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
6881
tool_calls: Optional[List[ToolCall]] = None
6982

70-
7183
class UserMessage(BaseMessage):
7284
"""
7385
A user message.
7486
"""
75-
role: Literal["user"]
87+
role: Literal[Role.USER] = Field(Role.USER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
7688
content: str
7789

7890

@@ -81,7 +93,7 @@ class ToolMessage(ConfiguredBaseModel):
8193
A tool result message.
8294
"""
8395
id: str
84-
role: Literal["tool"]
96+
role: Literal[Role.TOOL] = Field(Role.TOOL, init=False)
8597
content: str
8698
tool_call_id: str
8799

@@ -92,9 +104,6 @@ class ToolMessage(ConfiguredBaseModel):
92104
]
93105

94106

95-
Role = Literal["developer", "system", "assistant", "user", "tool"]
96-
97-
98107
class Context(ConfiguredBaseModel):
99108
"""
100109
Additional context for the agent.

python-sdk/tests/test_encoder.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def test_encode_sse_method(self):
4343
"""Test the encode_sse method"""
4444
# Create a test event with specific data
4545
event = TextMessageContentEvent(
46-
type=EventType.TEXT_MESSAGE_CONTENT,
4746
message_id="msg_123",
4847
delta="Hello, world!",
4948
timestamp=1648214400000
@@ -83,7 +82,6 @@ def test_encode_with_different_event_types(self):
8382

8483
# Test with a more complex event
8584
content_event = TextMessageContentEvent(
86-
type=EventType.TEXT_MESSAGE_CONTENT,
8785
message_id="msg_456",
8886
delta="Testing different events",
8987
timestamp=1648214400000
@@ -130,7 +128,6 @@ def test_null_value_exclusion(self):
130128
# Test with another event that has optional fields
131129
# Create event with some optional fields set to None
132130
event_with_optional = ToolCallStartEvent(
133-
type=EventType.TOOL_CALL_START,
134131
tool_call_id="call_123",
135132
tool_call_name="test_tool",
136133
parent_message_id=None, # Optional field explicitly set to None
@@ -152,7 +149,6 @@ def test_round_trip_serialization(self):
152149
"""Test that events can be serialized to JSON with camelCase and deserialized back correctly"""
153150
# Create a complex event with multiple fields
154151
original_event = ToolCallStartEvent(
155-
type=EventType.TOOL_CALL_START,
156152
tool_call_id="call_abc123",
157153
tool_call_name="search_tool",
158154
parent_message_id="msg_parent_456",

0 commit comments

Comments
 (0)