diff --git a/python-sdk/ag_ui/core/events.py b/python-sdk/ag_ui/core/events.py index edc483a07..40b655a78 100644 --- a/python-sdk/ag_ui/core/events.py +++ b/python-sdk/ag_ui/core/events.py @@ -3,10 +3,11 @@ """ from enum import Enum -from typing import Any, List, Literal, Optional, Union, Annotated +from typing import Annotated, Any, List, Literal, Optional, Union + from pydantic import Field -from .types import Message, State, ConfiguredBaseModel +from .types import ConfiguredBaseModel, Message, State class EventType(str, Enum): @@ -46,36 +47,32 @@ class TextMessageStartEvent(BaseEvent): """ Event indicating the start of a text message. """ - type: Literal[EventType.TEXT_MESSAGE_START] + type: Literal[EventType.TEXT_MESSAGE_START] = EventType.TEXT_MESSAGE_START # pyright: ignore[reportIncompatibleVariableOverride] message_id: str - role: Literal["assistant"] + role: Literal["assistant"] = "assistant" class TextMessageContentEvent(BaseEvent): """ Event containing a piece of text message content. """ - type: Literal[EventType.TEXT_MESSAGE_CONTENT] + type: Literal[EventType.TEXT_MESSAGE_CONTENT] = EventType.TEXT_MESSAGE_CONTENT # pyright: ignore[reportIncompatibleVariableOverride] message_id: str - delta: str # This should not be an empty string - - def model_post_init(self, __context): - if len(self.delta) == 0: - raise ValueError("Delta must not be an empty string") + delta: str = Field(min_length=1) class TextMessageEndEvent(BaseEvent): """ Event indicating the end of a text message. """ - type: Literal[EventType.TEXT_MESSAGE_END] + type: Literal[EventType.TEXT_MESSAGE_END] = EventType.TEXT_MESSAGE_END # pyright: ignore[reportIncompatibleVariableOverride] message_id: str class TextMessageChunkEvent(BaseEvent): """ Event containing a chunk of text message content. """ - type: Literal[EventType.TEXT_MESSAGE_CHUNK] + type: Literal[EventType.TEXT_MESSAGE_CHUNK] = EventType.TEXT_MESSAGE_CHUNK # pyright: ignore[reportIncompatibleVariableOverride] message_id: Optional[str] = None role: Optional[Literal["assistant"]] = None delta: Optional[str] = None @@ -84,7 +81,7 @@ class ToolCallStartEvent(BaseEvent): """ Event indicating the start of a tool call. """ - type: Literal[EventType.TOOL_CALL_START] + type: Literal[EventType.TOOL_CALL_START] = EventType.TOOL_CALL_START # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str tool_call_name: str parent_message_id: Optional[str] = None @@ -94,7 +91,7 @@ class ToolCallArgsEvent(BaseEvent): """ Event containing tool call arguments. """ - type: Literal[EventType.TOOL_CALL_ARGS] + type: Literal[EventType.TOOL_CALL_ARGS] = EventType.TOOL_CALL_ARGS # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str delta: str @@ -103,14 +100,14 @@ class ToolCallEndEvent(BaseEvent): """ Event indicating the end of a tool call. """ - type: Literal[EventType.TOOL_CALL_END] + type: Literal[EventType.TOOL_CALL_END] = EventType.TOOL_CALL_END # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str class ToolCallChunkEvent(BaseEvent): """ Event containing a chunk of tool call content. """ - type: Literal[EventType.TOOL_CALL_CHUNK] + type: Literal[EventType.TOOL_CALL_CHUNK] = EventType.TOOL_CALL_CHUNK # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: Optional[str] = None tool_call_name: Optional[str] = None parent_message_id: Optional[str] = None @@ -120,7 +117,7 @@ class StateSnapshotEvent(BaseEvent): """ Event containing a snapshot of the state. """ - type: Literal[EventType.STATE_SNAPSHOT] + type: Literal[EventType.STATE_SNAPSHOT] = EventType.STATE_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride] snapshot: State @@ -128,7 +125,7 @@ class StateDeltaEvent(BaseEvent): """ Event containing a delta of the state. """ - type: Literal[EventType.STATE_DELTA] + type: Literal[EventType.STATE_DELTA] = EventType.STATE_DELTA # pyright: ignore[reportIncompatibleVariableOverride] delta: List[Any] # JSON Patch (RFC 6902) @@ -136,7 +133,7 @@ class MessagesSnapshotEvent(BaseEvent): """ Event containing a snapshot of the messages. """ - type: Literal[EventType.MESSAGES_SNAPSHOT] + type: Literal[EventType.MESSAGES_SNAPSHOT] = EventType.MESSAGES_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride] messages: List[Message] @@ -144,7 +141,7 @@ class RawEvent(BaseEvent): """ Event containing a raw event. """ - type: Literal[EventType.RAW] + type: Literal[EventType.RAW] = EventType.RAW # pyright: ignore[reportIncompatibleVariableOverride] event: Any source: Optional[str] = None @@ -153,7 +150,7 @@ class CustomEvent(BaseEvent): """ Event containing a custom event. """ - type: Literal[EventType.CUSTOM] + type: Literal[EventType.CUSTOM] = EventType.CUSTOM # pyright: ignore[reportIncompatibleVariableOverride] name: str value: Any @@ -162,7 +159,7 @@ class RunStartedEvent(BaseEvent): """ Event indicating that a run has started. """ - type: Literal[EventType.RUN_STARTED] + type: Literal[EventType.RUN_STARTED] = EventType.RUN_STARTED # pyright: ignore[reportIncompatibleVariableOverride] thread_id: str run_id: str @@ -171,7 +168,7 @@ class RunFinishedEvent(BaseEvent): """ Event indicating that a run has finished. """ - type: Literal[EventType.RUN_FINISHED] + type: Literal[EventType.RUN_FINISHED] = EventType.RUN_FINISHED # pyright: ignore[reportIncompatibleVariableOverride] thread_id: str run_id: str @@ -180,7 +177,7 @@ class RunErrorEvent(BaseEvent): """ Event indicating that a run has encountered an error. """ - type: Literal[EventType.RUN_ERROR] + type: Literal[EventType.RUN_ERROR] = EventType.RUN_ERROR # pyright: ignore[reportIncompatibleVariableOverride] message: str code: Optional[str] = None @@ -189,7 +186,7 @@ class StepStartedEvent(BaseEvent): """ Event indicating that a step has started. """ - type: Literal[EventType.STEP_STARTED] + type: Literal[EventType.STEP_STARTED] = EventType.STEP_STARTED # pyright: ignore[reportIncompatibleVariableOverride] step_name: str @@ -197,7 +194,7 @@ class StepFinishedEvent(BaseEvent): """ Event indicating that a step has finished. """ - type: Literal[EventType.STEP_FINISHED] + type: Literal[EventType.STEP_FINISHED] = EventType.STEP_FINISHED # pyright: ignore[reportIncompatibleVariableOverride] step_name: str diff --git a/python-sdk/ag_ui/core/types.py b/python-sdk/ag_ui/core/types.py index add219d44..61c5f30dd 100644 --- a/python-sdk/ag_ui/core/types.py +++ b/python-sdk/ag_ui/core/types.py @@ -2,10 +2,12 @@ This module contains the types for the Agent User Interaction Protocol Python SDK. """ -from typing import Any, List, Literal, Optional, Union, Annotated -from pydantic import BaseModel, Field, ConfigDict +from typing import Annotated, Any, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field from pydantic.alias_generators import to_camel + class ConfiguredBaseModel(BaseModel): """ A configurable base model. @@ -31,7 +33,7 @@ class ToolCall(ConfiguredBaseModel): A tool call, modelled after OpenAI tool calls. """ id: str - type: Literal["function"] + type: Literal["function"] = "function" # pyright: ignore[reportIncompatibleVariableOverride] function: FunctionCall @@ -49,7 +51,7 @@ class DeveloperMessage(BaseMessage): """ A developer message. """ - role: Literal["developer"] + role: Literal["developer"] = "developer" # pyright: ignore[reportIncompatibleVariableOverride] content: str @@ -57,7 +59,7 @@ class SystemMessage(BaseMessage): """ A system message. """ - role: Literal["system"] + role: Literal["system"] = "system" # pyright: ignore[reportIncompatibleVariableOverride] content: str @@ -65,7 +67,7 @@ class AssistantMessage(BaseMessage): """ An assistant message. """ - role: Literal["assistant"] + role: Literal["assistant"] = "assistant" # pyright: ignore[reportIncompatibleVariableOverride] tool_calls: Optional[List[ToolCall]] = None @@ -73,7 +75,7 @@ class UserMessage(BaseMessage): """ A user message. """ - role: Literal["user"] + role: Literal["user"] = "user" # pyright: ignore[reportIncompatibleVariableOverride] content: str @@ -82,7 +84,7 @@ class ToolMessage(ConfiguredBaseModel): A tool result message. """ id: str - role: Literal["tool"] + role: Literal["tool"] = "tool" content: str tool_call_id: str @@ -92,7 +94,6 @@ class ToolMessage(ConfiguredBaseModel): Field(discriminator="role") ] - Role = Literal["developer", "system", "assistant", "user", "tool"] diff --git a/python-sdk/tests/test_encoder.py b/python-sdk/tests/test_encoder.py index ef1153148..2d466c5a4 100644 --- a/python-sdk/tests/test_encoder.py +++ b/python-sdk/tests/test_encoder.py @@ -43,7 +43,6 @@ def test_encode_sse_method(self): """Test the encode_sse method""" # Create a test event with specific data event = TextMessageContentEvent( - type=EventType.TEXT_MESSAGE_CONTENT, message_id="msg_123", delta="Hello, world!", timestamp=1648214400000 @@ -83,7 +82,6 @@ def test_encode_with_different_event_types(self): # Test with a more complex event content_event = TextMessageContentEvent( - type=EventType.TEXT_MESSAGE_CONTENT, message_id="msg_456", delta="Testing different events", timestamp=1648214400000 @@ -130,7 +128,6 @@ def test_null_value_exclusion(self): # Test with another event that has optional fields # Create event with some optional fields set to None event_with_optional = ToolCallStartEvent( - type=EventType.TOOL_CALL_START, tool_call_id="call_123", tool_call_name="test_tool", parent_message_id=None, # Optional field explicitly set to None @@ -152,7 +149,6 @@ def test_round_trip_serialization(self): """Test that events can be serialized to JSON with camelCase and deserialized back correctly""" # Create a complex event with multiple fields original_event = ToolCallStartEvent( - type=EventType.TOOL_CALL_START, tool_call_id="call_abc123", tool_call_name="search_tool", parent_message_id="msg_parent_456", diff --git a/python-sdk/tests/test_events.py b/python-sdk/tests/test_events.py index 929b543a2..c73a2537c 100644 --- a/python-sdk/tests/test_events.py +++ b/python-sdk/tests/test_events.py @@ -49,9 +49,7 @@ def test_base_event_creation(self): def test_text_message_start(self): """Test creating and serializing a TextMessageStartEvent event""" event = TextMessageStartEvent( - type=EventType.TEXT_MESSAGE_START, message_id="msg_123", - role="assistant", timestamp=1648214400000 ) self.assertEqual(event.message_id, "msg_123") @@ -66,7 +64,6 @@ def test_text_message_start(self): def test_text_message_content(self): """Test creating and serializing a TextMessageContentEvent event""" event = TextMessageContentEvent( - type=EventType.TEXT_MESSAGE_CONTENT, message_id="msg_123", delta="Hello, world!", timestamp=1648214400000 @@ -83,7 +80,6 @@ def test_text_message_content(self): def test_text_message_end(self): """Test creating and serializing a TextMessageEndEvent event""" event = TextMessageEndEvent( - type=EventType.TEXT_MESSAGE_END, message_id="msg_123", timestamp=1648214400000 ) @@ -97,7 +93,6 @@ def test_text_message_end(self): def test_tool_call_start(self): """Test creating and serializing a ToolCallStartEvent event""" event = ToolCallStartEvent( - type=EventType.TOOL_CALL_START, tool_call_id="call_123", tool_call_name="get_weather", parent_message_id="msg_456", @@ -117,7 +112,6 @@ def test_tool_call_start(self): def test_tool_call_args(self): """Test creating and serializing a ToolCallArgsEvent event""" event = ToolCallArgsEvent( - type=EventType.TOOL_CALL_ARGS, tool_call_id="call_123", delta='{"location": "New York"}', timestamp=1648214400000 @@ -134,7 +128,6 @@ def test_tool_call_args(self): def test_tool_call_end(self): """Test creating and serializing a ToolCallEndEvent event""" event = ToolCallEndEvent( - type=EventType.TOOL_CALL_END, tool_call_id="call_123", timestamp=1648214400000 ) @@ -149,7 +142,6 @@ def test_state_snapshot(self): """Test creating and serializing a StateSnapshotEvent event""" state = {"conversation_state": "active", "user_info": {"name": "John"}} event = StateSnapshotEvent( - type=EventType.STATE_SNAPSHOT, snapshot=state, timestamp=1648214400000 ) @@ -169,7 +161,6 @@ def test_state_delta(self): {"op": "add", "path": "/user_info/age", "value": 30} ] event = StateDeltaEvent( - type=EventType.STATE_DELTA, delta=delta, timestamp=1648214400000 ) @@ -185,11 +176,10 @@ def test_state_delta(self): def test_messages_snapshot(self): """Test creating and serializing a MessagesSnapshotEvent event""" messages = [ - UserMessage(id="user_1", role="user", content="Hello"), - AssistantMessage(id="asst_1", role="assistant", content="Hi there", tool_calls=[ + UserMessage(id="user_1", content="Hello"), + AssistantMessage(id="asst_1", content="Hi there", tool_calls=[ ToolCall( id="call_1", - type="function", function=FunctionCall( name="get_weather", arguments='{"location": "New York"}' @@ -198,7 +188,6 @@ def test_messages_snapshot(self): ]) ] event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, messages=messages, timestamp=1648214400000 ) @@ -217,7 +206,6 @@ def test_raw_event(self): """Test creating and serializing a RawEvent""" raw_data = {"origin": "server", "data": {"key": "value"}} event = RawEvent( - type=EventType.RAW, event=raw_data, source="api", timestamp=1648214400000 @@ -234,7 +222,6 @@ def test_raw_event(self): def test_custom_event(self): """Test creating and serializing a CustomEvent""" event = CustomEvent( - type=EventType.CUSTOM, name="user_action", value={"action": "click", "element": "button"}, timestamp=1648214400000 @@ -251,7 +238,6 @@ def test_custom_event(self): def test_run_started(self): """Test creating and serializing a RunStartedEvent event""" event = RunStartedEvent( - type=EventType.RUN_STARTED, thread_id="thread_123", run_id="run_456", timestamp=1648214400000 @@ -268,7 +254,6 @@ def test_run_started(self): def test_run_finished(self): """Test creating and serializing a RunFinishedEvent event""" event = RunFinishedEvent( - type=EventType.RUN_FINISHED, thread_id="thread_123", run_id="run_456", timestamp=1648214400000 @@ -285,7 +270,6 @@ def test_run_finished(self): def test_run_error(self): """Test creating and serializing a RunErrorEvent event""" event = RunErrorEvent( - type=EventType.RUN_ERROR, message="An error occurred during execution", code="ERROR_001", timestamp=1648214400000 @@ -302,7 +286,6 @@ def test_run_error(self): def test_step_started(self): """Test creating and serializing a StepStartedEvent event""" event = StepStartedEvent( - type=EventType.STEP_STARTED, step_name="process_data", timestamp=1648214400000 ) @@ -316,7 +299,6 @@ def test_step_started(self): def test_step_finished(self): """Test creating and serializing a StepFinishedEvent event""" event = StepFinishedEvent( - type=EventType.STEP_FINISHED, step_name="process_data", timestamp=1648214400000 ) @@ -383,57 +365,34 @@ def test_validation_constraints(self): # TextMessageContentEvent delta cannot be empty with self.assertRaises(ValueError): TextMessageContentEvent( - type=EventType.TEXT_MESSAGE_CONTENT, message_id="msg_123", delta="" # Empty delta, should fail ) - - # TextMessageStartEvent role must be "assistant" - with self.assertRaises(ValidationError): - TextMessageStartEvent( - type=EventType.TEXT_MESSAGE_START, - message_id="msg_123", - role="user" # Invalid role, should be "assistant" - ) - - # Event type must match the class - with self.assertRaises(ValidationError): - TextMessageEndEvent( - type=EventType.TEXT_MESSAGE_START, # Wrong event type - message_id="msg_123" - ) def test_serialization_round_trip(self): """Test serialization and deserialization for different event types""" # Create events of different types events = [ TextMessageStartEvent( - type=EventType.TEXT_MESSAGE_START, message_id="msg_123", - role="assistant" ), TextMessageContentEvent( - type=EventType.TEXT_MESSAGE_CONTENT, message_id="msg_123", delta="Hello, world!" ), ToolCallStartEvent( - type=EventType.TOOL_CALL_START, tool_call_id="call_123", tool_call_name="get_weather" ), StateSnapshotEvent( - type=EventType.STATE_SNAPSHOT, snapshot={"status": "active"} ), MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, messages=[ - UserMessage(id="user_1", role="user", content="Hello") + UserMessage(id="user_1", content="Hello") ] ), RunStartedEvent( - type=EventType.RUN_STARTED, thread_id="thread_123", run_id="run_456" ) @@ -475,7 +434,6 @@ def test_serialization_round_trip(self): def test_raw_event_with_null_source(self): """Test RawEvent with null source""" event = RawEvent( - type=EventType.RAW, event={"data": "test"}, source=None # Explicit None ) @@ -522,7 +480,6 @@ def test_complex_nested_event_structures(self): } event = StateSnapshotEvent( - type=EventType.STATE_SNAPSHOT, snapshot=complex_state, timestamp=1648214400000 ) @@ -551,7 +508,6 @@ def test_event_with_unicode_and_special_chars(self): text = "Hello 你好 こんにちは 안녕하세요 👋 🌍 \n\t\"'\\/<>{}[]" event = TextMessageContentEvent( - type=EventType.TEXT_MESSAGE_CONTENT, message_id="msg_unicode", delta=text, timestamp=1648214400000 diff --git a/python-sdk/tests/test_types.py b/python-sdk/tests/test_types.py index a71a232e5..e534aa5ab 100644 --- a/python-sdk/tests/test_types.py +++ b/python-sdk/tests/test_types.py @@ -28,7 +28,6 @@ def test_message_serialization(self): """Test serialization of a basic message""" user_msg = UserMessage( id="msg_123", - role="user", content="Hello, world!" ) serialized = user_msg.model_dump(by_alias=True) @@ -40,7 +39,6 @@ def test_tool_call_serialization(self): """Test camel case serialization for ConfiguredBaseModel subclasses""" tool_call = ToolCall( id="call_123", - type="function", function=FunctionCall(name="test_function", arguments="{}") ) serialized = tool_call.model_dump(by_alias=True) @@ -51,7 +49,6 @@ def test_tool_message_camel_case(self): """Test camel case serialization for ToolMessage""" tool_msg = ToolMessage( id="tool_123", - role="tool", content="Tool result", tool_call_id="call_456" ) @@ -103,7 +100,6 @@ def test_developer_message(self): """Test creating and serializing a developer message""" msg = DeveloperMessage( id="dev_123", - role="developer", content="Developer note" ) serialized = msg.model_dump(by_alias=True) @@ -114,7 +110,6 @@ def test_system_message(self): """Test creating and serializing a system message""" msg = SystemMessage( id="sys_123", - role="system", content="System instruction" ) serialized = msg.model_dump(by_alias=True) @@ -125,12 +120,10 @@ def test_assistant_message(self): """Test creating and serializing an assistant message with tool calls""" tool_call = ToolCall( id="call_456", - type="function", function=FunctionCall(name="get_data", arguments='{"param": "value"}') ) msg = AssistantMessage( id="asst_123", - role="assistant", content="Assistant response", tool_calls=[tool_call] ) @@ -144,7 +137,6 @@ def test_user_message(self): """Test creating and serializing a user message""" msg = UserMessage( id="user_123", - role="user", content="User query" ) serialized = msg.model_dump(by_alias=True)