Skip to content

Commit 049bb62

Browse files
committed
chore: remove init=False from default literals
Remove `init=False` from default literals to improve compatibility with existing code. Minimise changes by using literal strings instead of enum.
1 parent 9e12e6b commit 049bb62

File tree

2 files changed

+34
-41
lines changed

2 files changed

+34
-41
lines changed

python-sdk/ag_ui/core/events.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
"""
44

55
from enum import Enum
6-
from typing import Any, List, Literal, Optional, Union, Annotated
6+
from typing import Annotated, Any, List, Literal, Optional, Union
7+
78
from pydantic import Field
89

9-
from .types import Message, State, ConfiguredBaseModel, Role
10+
from .types import ConfiguredBaseModel, Message, State
1011

1112

1213
class EventType(str, Enum):
@@ -46,16 +47,16 @@ class TextMessageStartEvent(BaseEvent):
4647
"""
4748
Event indicating the start of a text message.
4849
"""
49-
type: Literal[EventType.TEXT_MESSAGE_START] = Field(EventType.TEXT_MESSAGE_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
50+
type: Literal[EventType.TEXT_MESSAGE_START] = EventType.TEXT_MESSAGE_START # pyright: ignore[reportIncompatibleVariableOverride]
5051
message_id: str
51-
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False)
52+
role: Literal["assistant"] = "assistant"
5253

5354

5455
class TextMessageContentEvent(BaseEvent):
5556
"""
5657
Event containing a piece of text message content.
5758
"""
58-
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = Field(EventType.TEXT_MESSAGE_CONTENT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
59+
type: Literal[EventType.TEXT_MESSAGE_CONTENT] = EventType.TEXT_MESSAGE_CONTENT # pyright: ignore[reportIncompatibleVariableOverride]
5960
message_id: str
6061
delta: str = Field(min_length=1)
6162

@@ -64,14 +65,14 @@ class TextMessageEndEvent(BaseEvent):
6465
"""
6566
Event indicating the end of a text message.
6667
"""
67-
type: Literal[EventType.TEXT_MESSAGE_END] = Field(EventType.TEXT_MESSAGE_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
68+
type: Literal[EventType.TEXT_MESSAGE_END] = EventType.TEXT_MESSAGE_END # pyright: ignore[reportIncompatibleVariableOverride]
6869
message_id: str
6970

7071
class TextMessageChunkEvent(BaseEvent):
7172
"""
7273
Event containing a chunk of text message content.
7374
"""
74-
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = Field(EventType.TEXT_MESSAGE_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
75+
type: Literal[EventType.TEXT_MESSAGE_CHUNK] = EventType.TEXT_MESSAGE_CHUNK # pyright: ignore[reportIncompatibleVariableOverride]
7576
message_id: Optional[str] = None
7677
role: Optional[Literal["assistant"]] = None
7778
delta: Optional[str] = None
@@ -80,7 +81,7 @@ class ToolCallStartEvent(BaseEvent):
8081
"""
8182
Event indicating the start of a tool call.
8283
"""
83-
type: Literal[EventType.TOOL_CALL_START] = Field(EventType.TOOL_CALL_START, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
84+
type: Literal[EventType.TOOL_CALL_START] = EventType.TOOL_CALL_START # pyright: ignore[reportIncompatibleVariableOverride]
8485
tool_call_id: str
8586
tool_call_name: str
8687
parent_message_id: Optional[str] = None
@@ -90,7 +91,7 @@ class ToolCallArgsEvent(BaseEvent):
9091
"""
9192
Event containing tool call arguments.
9293
"""
93-
type: Literal[EventType.TOOL_CALL_ARGS] = Field(EventType.TOOL_CALL_ARGS, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
94+
type: Literal[EventType.TOOL_CALL_ARGS] = EventType.TOOL_CALL_ARGS # pyright: ignore[reportIncompatibleVariableOverride]
9495
tool_call_id: str
9596
delta: str
9697

@@ -99,14 +100,14 @@ class ToolCallEndEvent(BaseEvent):
99100
"""
100101
Event indicating the end of a tool call.
101102
"""
102-
type: Literal[EventType.TOOL_CALL_END] = Field(EventType.TOOL_CALL_END, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
103+
type: Literal[EventType.TOOL_CALL_END] = EventType.TOOL_CALL_END # pyright: ignore[reportIncompatibleVariableOverride]
103104
tool_call_id: str
104105

105106
class ToolCallChunkEvent(BaseEvent):
106107
"""
107108
Event containing a chunk of tool call content.
108109
"""
109-
type: Literal[EventType.TOOL_CALL_CHUNK] = Field(EventType.TOOL_CALL_CHUNK, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
110+
type: Literal[EventType.TOOL_CALL_CHUNK] = EventType.TOOL_CALL_CHUNK # pyright: ignore[reportIncompatibleVariableOverride]
110111
tool_call_id: Optional[str] = None
111112
tool_call_name: Optional[str] = None
112113
parent_message_id: Optional[str] = None
@@ -116,31 +117,31 @@ class StateSnapshotEvent(BaseEvent):
116117
"""
117118
Event containing a snapshot of the state.
118119
"""
119-
type: Literal[EventType.STATE_SNAPSHOT] = Field(EventType.STATE_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
120+
type: Literal[EventType.STATE_SNAPSHOT] = EventType.STATE_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride]
120121
snapshot: State
121122

122123

123124
class StateDeltaEvent(BaseEvent):
124125
"""
125126
Event containing a delta of the state.
126127
"""
127-
type: Literal[EventType.STATE_DELTA] = Field(EventType.STATE_DELTA, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
128+
type: Literal[EventType.STATE_DELTA] = EventType.STATE_DELTA # pyright: ignore[reportIncompatibleVariableOverride]
128129
delta: List[Any] # JSON Patch (RFC 6902)
129130

130131

131132
class MessagesSnapshotEvent(BaseEvent):
132133
"""
133134
Event containing a snapshot of the messages.
134135
"""
135-
type: Literal[EventType.MESSAGES_SNAPSHOT] = Field(EventType.MESSAGES_SNAPSHOT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
136+
type: Literal[EventType.MESSAGES_SNAPSHOT] = EventType.MESSAGES_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride]
136137
messages: List[Message]
137138

138139

139140
class RawEvent(BaseEvent):
140141
"""
141142
Event containing a raw event.
142143
"""
143-
type: Literal[EventType.RAW] = Field(EventType.RAW, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
144+
type: Literal[EventType.RAW] = EventType.RAW # pyright: ignore[reportIncompatibleVariableOverride]
144145
event: Any
145146
source: Optional[str] = None
146147

@@ -149,7 +150,7 @@ class CustomEvent(BaseEvent):
149150
"""
150151
Event containing a custom event.
151152
"""
152-
type: Literal[EventType.CUSTOM] = Field(EventType.CUSTOM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
153+
type: Literal[EventType.CUSTOM] = EventType.CUSTOM # pyright: ignore[reportIncompatibleVariableOverride]
153154
name: str
154155
value: Any
155156

@@ -158,7 +159,7 @@ class RunStartedEvent(BaseEvent):
158159
"""
159160
Event indicating that a run has started.
160161
"""
161-
type: Literal[EventType.RUN_STARTED] = Field(EventType.RUN_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
162+
type: Literal[EventType.RUN_STARTED] = EventType.RUN_STARTED # pyright: ignore[reportIncompatibleVariableOverride]
162163
thread_id: str
163164
run_id: str
164165

@@ -167,7 +168,7 @@ class RunFinishedEvent(BaseEvent):
167168
"""
168169
Event indicating that a run has finished.
169170
"""
170-
type: Literal[EventType.RUN_FINISHED] = Field(EventType.RUN_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
171+
type: Literal[EventType.RUN_FINISHED] = EventType.RUN_FINISHED # pyright: ignore[reportIncompatibleVariableOverride]
171172
thread_id: str
172173
run_id: str
173174

@@ -176,7 +177,7 @@ class RunErrorEvent(BaseEvent):
176177
"""
177178
Event indicating that a run has encountered an error.
178179
"""
179-
type: Literal[EventType.RUN_ERROR] = Field(EventType.RUN_ERROR, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
180+
type: Literal[EventType.RUN_ERROR] = EventType.RUN_ERROR # pyright: ignore[reportIncompatibleVariableOverride]
180181
message: str
181182
code: Optional[str] = None
182183

@@ -185,15 +186,15 @@ class StepStartedEvent(BaseEvent):
185186
"""
186187
Event indicating that a step has started.
187188
"""
188-
type: Literal[EventType.STEP_STARTED] = Field(EventType.STEP_STARTED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
189+
type: Literal[EventType.STEP_STARTED] = EventType.STEP_STARTED # pyright: ignore[reportIncompatibleVariableOverride]
189190
step_name: str
190191

191192

192193
class StepFinishedEvent(BaseEvent):
193194
"""
194195
Event indicating that a step has finished.
195196
"""
196-
type: Literal[EventType.STEP_FINISHED] = Field(EventType.STEP_FINISHED, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
197+
type: Literal[EventType.STEP_FINISHED] = EventType.STEP_FINISHED # pyright: ignore[reportIncompatibleVariableOverride]
197198
step_name: str
198199

199200

python-sdk/ag_ui/core/types.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,10 @@
22
This module contains the types for the Agent User Interaction Protocol Python SDK.
33
"""
44

5-
from enum import Enum
6-
from typing import Any, List, Literal, Optional, Union, Annotated
7-
from pydantic import BaseModel, Field, ConfigDict
8-
from pydantic.alias_generators import to_camel
9-
5+
from typing import Annotated, Any, List, Literal, Optional, Union
106

11-
class Role(str, Enum):
12-
"""
13-
The role of an actor.
14-
"""
15-
DEVELOPER = "developer"
16-
SYSTEM = "system"
17-
ASSISTANT = "assistant"
18-
USER = "user"
19-
TOOL = "tool"
7+
from pydantic import BaseModel, ConfigDict, Field
8+
from pydantic.alias_generators import to_camel
209

2110

2211
class ConfiguredBaseModel(BaseModel):
@@ -44,7 +33,7 @@ class ToolCall(ConfiguredBaseModel):
4433
A tool call, modelled after OpenAI tool calls.
4534
"""
4635
id: str
47-
type: Literal["function"] = Field("function", init=False) # pyright: ignore[reportIncompatibleVariableOverride]
36+
type: Literal["function"] = "function" # pyright: ignore[reportIncompatibleVariableOverride]
4837
function: FunctionCall
4938

5039

@@ -62,30 +51,31 @@ class DeveloperMessage(BaseMessage):
6251
"""
6352
A developer message.
6453
"""
65-
role: Literal[Role.DEVELOPER] = Field(Role.DEVELOPER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
54+
role: Literal["developer"] = "developer" # pyright: ignore[reportIncompatibleVariableOverride]
6655
content: str
6756

6857

6958
class SystemMessage(BaseMessage):
7059
"""
7160
A system message.
7261
"""
73-
role: Literal[Role.SYSTEM] = Field(Role.SYSTEM, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
62+
role: Literal["system"] = "system" # pyright: ignore[reportIncompatibleVariableOverride]
7463
content: str
7564

7665

7766
class AssistantMessage(BaseMessage):
7867
"""
7968
An assistant message.
8069
"""
81-
role: Literal[Role.ASSISTANT] = Field(Role.ASSISTANT, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
70+
role: Literal["assistant"] = "assistant" # pyright: ignore[reportIncompatibleVariableOverride]
8271
tool_calls: Optional[List[ToolCall]] = None
8372

73+
8474
class UserMessage(BaseMessage):
8575
"""
8676
A user message.
8777
"""
88-
role: Literal[Role.USER] = Field(Role.USER, init=False) # pyright: ignore[reportIncompatibleVariableOverride]
78+
role: Literal["user"] = "user" # pyright: ignore[reportIncompatibleVariableOverride]
8979
content: str
9080

9181

@@ -94,7 +84,7 @@ class ToolMessage(ConfiguredBaseModel):
9484
A tool result message.
9585
"""
9686
id: str
97-
role: Literal[Role.TOOL] = Field(Role.TOOL, init=False)
87+
role: Literal["tool"] = "tool"
9888
content: str
9989
tool_call_id: str
10090

@@ -104,6 +94,8 @@ class ToolMessage(ConfiguredBaseModel):
10494
Field(discriminator="role")
10595
]
10696

97+
Role = Literal["developer", "system", "assistant", "user", "tool"]
98+
10799

108100
class Context(ConfiguredBaseModel):
109101
"""

0 commit comments

Comments
 (0)