Skip to content

Commit 49a586d

Browse files
stevenhmaxkorp
authored andcommitted
chore: remove init=False from default literals
Remove `init=False` from default literals to improve compatibility with existing code. Minimise changes by using literal strings instead of enum.
1 parent 925d2e0 commit 49a586d

File tree

2 files changed

+34
-41
lines changed

2 files changed

+34
-41
lines changed

python-sdk/ag_ui/core/events.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
"""
44

55
from enum import Enum
6-
from typing import Any, List, Literal, Optional, Union, Annotated
6+
from typing import Annotated, Any, List, Literal, Optional, Union
7+
78
from pydantic import Field
89

9-
from .types import Message, State, ConfiguredBaseModel, Role
10+
from .types import ConfiguredBaseModel, Message, State
1011

1112

1213
class EventType(str, Enum):
@@ -52,16 +53,16 @@ class TextMessageStartEvent(BaseEvent):
5253
"""
5354
Event indicating the start of a text message.
5455
"""
55-
type: Literal[EventType.TEXT_MESSAGE_START] = Field(EventType.TEXT_MESSAGE_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
56+
type: Literal[EventType.TEXT_MESSAGE_START] = EventType.TEXT_MESSAGE_START # pyright: ignore[reportIncompatibleVariableOverride]
5657
message_id: str
57-
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False)
58+
role: Literal["assistant"] = "assistant"
5859

5960

6061
class TextMessageContentEvent(BaseEvent):
6162
"""
6263
Event containing a piece of text message content.
6364
"""
64-
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = Field(EventType.TEXT_MESSAGE_CONTENT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
65+
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = EventType.TEXT_MESSAGE_CONTENT # pyright: ignore[reportIncompatibleVariableOverride]
6566
message_id: str
6667
delta: str = Field(min_length=1)
6768

@@ -70,14 +71,14 @@ class TextMessageEndEvent(BaseEvent):
7071
"""
7172
Event indicating the end of a text message.
7273
"""
73-
type: Literal[EventType.TEXT_MESSAGE_END] = Field(EventType.TEXT_MESSAGE_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
74+
type: Literal[EventType.TEXT_MESSAGE_END] = EventType.TEXT_MESSAGE_END # pyright: ignore[reportIncompatibleVariableOverride]
7475
message_id: str
7576

7677
class TextMessageChunkEvent(BaseEvent):
7778
"""
7879
Event containing a chunk of text message content.
7980
"""
80-
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = Field(EventType.TEXT_MESSAGE_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
81+
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = EventType.TEXT_MESSAGE_CHUNK # pyright: ignore[reportIncompatibleVariableOverride]
8182
message_id: Optional[str] = None
8283
role: Optional[Literal["assistant"]] = None
8384
delta: Optional[str] = None
@@ -109,7 +110,7 @@ class ToolCallStartEvent(BaseEvent):
109110
"""
110111
Event indicating the start of a tool call.
111112
"""
112-
type: Literal[EventType.TOOL_CALL_START] = Field(EventType.TOOL_CALL_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
113+
type: Literal[EventType.TOOL_CALL_START] = EventType.TOOL_CALL_START # pyright: ignore[reportIncompatibleVariableOverride]
113114
tool_call_id: str
114115
tool_call_name: str
115116
parent_message_id: Optional[str] = None
@@ -119,7 +120,7 @@ class ToolCallArgsEvent(BaseEvent):
119120
"""
120121
Event containing tool call arguments.
121122
"""
122-
type: Literal[EventType.TOOL_CALL_ARGS] = Field(EventType.TOOL_CALL_ARGS, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
123+
type: Literal[EventType.TOOL_CALL_ARGS] = EventType.TOOL_CALL_ARGS # pyright: ignore[reportIncompatibleVariableOverride]
123124
tool_call_id: str
124125
delta: str
125126

@@ -128,14 +129,14 @@ class ToolCallEndEvent(BaseEvent):
128129
"""
129130
Event indicating the end of a tool call.
130131
"""
131-
type: Literal[EventType.TOOL_CALL_END] = Field(EventType.TOOL_CALL_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
132+
type: Literal[EventType.TOOL_CALL_END] = EventType.TOOL_CALL_END # pyright: ignore[reportIncompatibleVariableOverride]
132133
tool_call_id: str
133134

134135
class ToolCallChunkEvent(BaseEvent):
135136
"""
136137
Event containing a chunk of tool call content.
137138
"""
138-
type: Literal[EventType.TOOL_CALL_CHUNK] = Field(EventType.TOOL_CALL_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
139+
type: Literal[EventType.TOOL_CALL_CHUNK] = EventType.TOOL_CALL_CHUNK # pyright: ignore[reportIncompatibleVariableOverride]
139140
tool_call_id: Optional[str] = None
140141
tool_call_name: Optional[str] = None
141142
parent_message_id: Optional[str] = None
@@ -168,31 +169,31 @@ class StateSnapshotEvent(BaseEvent):
168169
"""
169170
Event containing a snapshot of the state.
170171
"""
171-
type: Literal[EventType.STATE_SNAPSHOT] = Field(EventType.STATE_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
172+
type: Literal[EventType.STATE_SNAPSHOT] = EventType.STATE_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride]
172173
snapshot: State
173174

174175

175176
class StateDeltaEvent(BaseEvent):
176177
"""
177178
Event containing a delta of the state.
178179
"""
179-
type: Literal[EventType.STATE_DELTA] = Field(EventType.STATE_DELTA, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
180+
type: Literal[EventType.STATE_DELTA] = EventType.STATE_DELTA # pyright: ignore[reportIncompatibleVariableOverride]
180181
delta: List[Any] # JSON Patch (RFC 6902)
181182

182183

183184
class MessagesSnapshotEvent(BaseEvent):
184185
"""
185186
Event containing a snapshot of the messages.
186187
"""
187-
type: Literal[EventType.MESSAGES_SNAPSHOT] = Field(EventType.MESSAGES_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
188+
type: Literal[EventType.MESSAGES_SNAPSHOT] = EventType.MESSAGES_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride]
188189
messages: List[Message]
189190

190191

191192
class RawEvent(BaseEvent):
192193
"""
193194
Event containing a raw event.
194195
"""
195-
type: Literal[EventType.RAW] = Field(EventType.RAW, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
196+
type: Literal[EventType.RAW] = EventType.RAW # pyright: ignore[reportIncompatibleVariableOverride]
196197
event: Any
197198
source: Optional[str] = None
198199

@@ -201,7 +202,7 @@ class CustomEvent(BaseEvent):
201202
"""
202203
Event containing a custom event.
203204
"""
204-
type: Literal[EventType.CUSTOM] = Field(EventType.CUSTOM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
205+
type: Literal[EventType.CUSTOM] = EventType.CUSTOM # pyright: ignore[reportIncompatibleVariableOverride]
205206
name: str
206207
value: Any
207208

@@ -210,7 +211,7 @@ class RunStartedEvent(BaseEvent):
210211
"""
211212
Event indicating that a run has started.
212213
"""
213-
type: Literal[EventType.RUN_STARTED] = Field(EventType.RUN_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
214+
type: Literal[EventType.RUN_STARTED] = EventType.RUN_STARTED # pyright: ignore[reportIncompatibleVariableOverride]
214215
thread_id: str
215216
run_id: str
216217

@@ -219,7 +220,7 @@ class RunFinishedEvent(BaseEvent):
219220
"""
220221
Event indicating that a run has finished.
221222
"""
222-
type: Literal[EventType.RUN_FINISHED] = Field(EventType.RUN_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
223+
type: Literal[EventType.RUN_FINISHED] = EventType.RUN_FINISHED # pyright: ignore[reportIncompatibleVariableOverride]
223224
thread_id: str
224225
run_id: str
225226
result: Optional[Any] = None
@@ -229,7 +230,7 @@ class RunErrorEvent(BaseEvent):
229230
"""
230231
Event indicating that a run has encountered an error.
231232
"""
232-
type: Literal[EventType.RUN_ERROR] = Field(EventType.RUN_ERROR, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
233+
type: Literal[EventType.RUN_ERROR] = EventType.RUN_ERROR # pyright: ignore[reportIncompatibleVariableOverride]
233234
message: str
234235
code: Optional[str] = None
235236

@@ -238,15 +239,15 @@ class StepStartedEvent(BaseEvent):
238239
"""
239240
Event indicating that a step has started.
240241
"""
241-
type: Literal[EventType.STEP_STARTED] = Field(EventType.STEP_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
242+
type: Literal[EventType.STEP_STARTED] = EventType.STEP_STARTED # pyright: ignore[reportIncompatibleVariableOverride]
242243
step_name: str
243244

244245

245246
class StepFinishedEvent(BaseEvent):
246247
"""
247248
Event indicating that a step has finished.
248249
"""
249-
type: Literal[EventType.STEP_FINISHED] = Field(EventType.STEP_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
250+
type: Literal[EventType.STEP_FINISHED] = EventType.STEP_FINISHED # pyright: ignore[reportIncompatibleVariableOverride]
250251
step_name: str
251252

252253

python-sdk/ag_ui/core/types.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,10 @@
22
This module contains the types for the Agent User Interaction Protocol Python SDK.
33
"""
44

5-
from enum import Enum
6-
from typing import Any, List, Literal, Optional, Union, Annotated
7-
from pydantic import BaseModel, Field, ConfigDict
8-
from pydantic.alias_generators import to_camel
9-
5+
from typing import Annotated, Any, List, Literal, Optional, Union
106

11-
class Role(str, Enum):
12-
"""
13-
The role of an actor.
14-
"""
15-
DEVELOPER = "developer"
16-
SYSTEM = "system"
17-
ASSISTANT = "assistant"
18-
USER = "user"
19-
TOOL = "tool"
7+
from pydantic import BaseModel, ConfigDict, Field
8+
from pydantic.alias_generators import to_camel
209

2110

2211
class ConfiguredBaseModel(BaseModel):
@@ -43,7 +32,7 @@ class ToolCall(ConfiguredBaseModel):
4332
A tool call, modelled after OpenAI tool calls.
4433
"""
4534
id: str
46-
type: Literal["function"] = Field("function", init=False) # pyright: ignore[reportIncompatibleVariableOverride]
35+
type: Literal["function"] = "function" # pyright: ignore[reportIncompatibleVariableOverride]
4736
function: FunctionCall
4837

4938

@@ -61,30 +50,31 @@ class DeveloperMessage(BaseMessage):
6150
"""
6251
A developer message.
6352
"""
64-
role: Literal[Role.DEVELOPER] = Field(Role.DEVELOPER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
53+
role: Literal["developer"] = "developer" # pyright: ignore[reportIncompatibleVariableOverride]
6554
content: str
6655

6756

6857
class SystemMessage(BaseMessage):
6958
"""
7059
A system message.
7160
"""
72-
role: Literal[Role.SYSTEM] = Field(Role.SYSTEM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
61+
role: Literal["system"] = "system" # pyright: ignore[reportIncompatibleVariableOverride]
7362
content: str
7463

7564

7665
class AssistantMessage(BaseMessage):
7766
"""
7867
An assistant message.
7968
"""
80-
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
69+
role: Literal["assistant"] = "assistant" # pyright: ignore[reportIncompatibleVariableOverride]
8170
tool_calls: Optional[List[ToolCall]] = None
8271

72+
8373
class UserMessage(BaseMessage):
8474
"""
8575
A user message.
8676
"""
87-
role: Literal[Role.USER] = Field(Role.USER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
77+
role: Literal["user"] = "user" # pyright: ignore[reportIncompatibleVariableOverride]
8878
content: str
8979

9080

@@ -93,7 +83,7 @@ class ToolMessage(ConfiguredBaseModel):
9383
A tool result message.
9484
"""
9585
id: str
96-
role: Literal[Role.TOOL] = Field(Role.TOOL, init=False)
86+
role: Literal["tool"] = "tool"
9787
content: str
9888
tool_call_id: str
9989

@@ -103,6 +93,8 @@ class ToolMessage(ConfiguredBaseModel):
10393
Field(discriminator="role")
10494
]
10595

96+
Role = Literal["developer", "system", "assistant", "user", "tool"]
97+
10698

10799
class Context(ConfiguredBaseModel):
108100
"""

0 commit comments

Comments
 (0)