diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
index afd3407620b7..1ebe658c18e4 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
@@ -115,7 +115,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
         )
 
         # Append all messages to thread
-        self._message_thread.extend(message.messages)
+        await self.update_message_thread(message.messages)
 
         # Check termination condition after processing all messages
         if await self._apply_termination_condition(message.messages):
@@ -139,6 +139,9 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
             cancellation_token=ctx.cancellation_token,
         )
 
+    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
+        self._message_thread.extend(messages)
+
     @event
     async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
         try:
@@ -146,10 +149,9 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
             delta: List[BaseAgentEvent | BaseChatMessage] = []
             if message.agent_response.inner_messages is not None:
                 for inner_message in message.agent_response.inner_messages:
-                    self._message_thread.append(inner_message)
                     delta.append(inner_message)
-            self._message_thread.append(message.agent_response.chat_message)
             delta.append(message.agent_response.chat_message)
+            await self.update_message_thread(delta)
 
             # Check if the conversation should be terminated.
             if await self._apply_termination_condition(delta, increment_turn_count=True):
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py
index 78bf3929a046..34d1df7cf948 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py
@@ -191,7 +191,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
         if message.agent_response.inner_messages is not None:
             for inner_message in message.agent_response.inner_messages:
                 delta.append(inner_message)
-        self._message_thread.append(message.agent_response.chat_message)
+        await self.update_message_thread([message.agent_response.chat_message])
         delta.append(message.agent_response.chat_message)
 
         if self._termination_condition is not None:
@@ -263,7 +263,7 @@ async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> No
         )
 
         # Save my copy
-        self._message_thread.append(ledger_message)
+        await self.update_message_thread([ledger_message])
 
         # Log it to the output topic.
         await self.publish_message(
@@ -376,7 +376,7 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None
         # Broadcast the next step
         message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
-        self._message_thread.append(message)  # My copy
+        await self.update_message_thread([message])  # My copy
 
         await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
         # Log it to the output topic.
@@ -458,7 +458,7 @@ async def _prepare_final_answer(self, reason: str, cancellation_token: Cancellat
         assert isinstance(response.content, str)
         message = TextMessage(content=response.content, source=self._name)
-        self._message_thread.append(message)  # My copy
+        await self.update_message_thread([message])  # My copy
 
         # Log it to the output topic.
         await self.publish_message(
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
index c43ab85b8bca..d54b7aea298f 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
@@ -4,11 +4,16 @@
 from inspect import iscoroutinefunction
 from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
 
-from autogen_core import AgentRuntime, Component, ComponentModel
+from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
     CreateResult,
+    LLMMessage,
     ModelFamily,
     SystemMessage,
     UserMessage,
@@ -22,6 +27,7 @@ from ...messages import (
     BaseAgentEvent,
     BaseChatMessage,
+    HandoffMessage,
     MessageFactory,
     ModelClientStreamingChunkEvent,
     SelectorEvent,
@@ -65,6 +71,7 @@ def __init__(
         max_selector_attempts: int,
         candidate_func: Optional[CandidateFuncType],
         emit_team_events: bool,
+        model_context: ChatCompletionContext | None,
         model_client_streaming: bool = False,
     ) -> None:
         super().__init__(
@@ -90,6 +97,11 @@ def __init__(
         self._candidate_func = candidate_func
         self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
         self._model_client_streaming = model_client_streaming
+        if model_context is not None:
+            self._model_context = model_context
+        else:
+            self._model_context = UnboundedChatCompletionContext()
+        self._cancellation_token = CancellationToken()
 
     async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
         pass
@@ -97,6 +109,7 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) ->
     async def reset(self) -> None:
         self._current_turn = 0
         self._message_thread.clear()
+        await self._model_context.clear()
         if self._termination_condition is not None:
             await self._termination_condition.reset()
         self._previous_speaker = None
@@ -112,16 +125,37 @@ async def save_state(self) -> Mapping[str, Any]:
     async def load_state(self, state: Mapping[str, Any]) -> None:
         selector_state = SelectorManagerState.model_validate(state)
         self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
+        await self._add_messages_to_context(
+            self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
+        )
         self._current_turn = selector_state.current_turn
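+        # The model context itself is not part of the saved state; it is rebuilt
+        # above by replaying the restored message thread.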
         self._previous_speaker = selector_state.previous_speaker
 
+    @staticmethod
+    async def _add_messages_to_context(
+        model_context: ChatCompletionContext,
+        messages: Sequence[BaseChatMessage],
+    ) -> None:
+        """
+        Add incoming messages to the model context.
+        """
+        for msg in messages:
+            if isinstance(msg, HandoffMessage):
+                for llm_msg in msg.context:
+                    await model_context.add_message(llm_msg)
+            await model_context.add_message(msg.to_model_message())
+
+    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
+        self._message_thread.extend(messages)
+        base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
+        await self._add_messages_to_context(self._model_context, base_chat_messages)
+
     async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
         """Selects the next speaker in a group chat using a ChatCompletion client,
         with the selector function as override if it returns a speaker name.
 
         A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
         """
-
         # Use the selector function if provided.
         if self._selector_func is not None:
             if self._is_selector_func_async:
@@ -163,18 +197,6 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -
         assert len(participants) > 0
 
-        # Construct the history of the conversation.
-        history_messages: List[str] = []
-        for msg in thread:
-            if not isinstance(msg, BaseChatMessage):
-                # Only process chat messages.
-                continue
-            message = f"{msg.source}: {msg.to_model_text()}"
-            history_messages.append(
-                message.rstrip() + "\n\n"
-            )  # Create some consistency for how messages are separated in the transcript
-        history = "\n".join(history_messages)
-
         # Construct agent roles.
         # Each agent should appear on a single line.
         roles = ""
@@ -184,17 +206,34 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -
         # Select the next speaker.
         if len(participants) > 1:
-            agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
+            agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
         else:
             agent_name = participants[0]
         self._previous_speaker = agent_name
         trace_logger.debug(f"Selected speaker: {agent_name}")
         return agent_name
 
-    async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str:
+    def construct_message_history(self, message_history: List[LLMMessage]) -> str:
+        # Construct the history of the conversation.
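+        # Only UserMessage and AssistantMessage entries carry a speaker/content
+        # pair worth transcribing; other LLM message types are skipped.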
+        history_messages: List[str] = []
+        for msg in message_history:
+            if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
+                message = f"{msg.source}: {msg.content}"
+                history_messages.append(
+                    message.rstrip() + "\n\n"
+                )  # Create some consistency for how messages are separated in the transcript
+
+        history: str = "\n".join(history_messages)
+        return history
+
+    async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
+        model_context_messages = await self._model_context.get_messages()
+        model_context_history = self.construct_message_history(model_context_messages)
+
         select_speaker_prompt = self._selector_prompt.format(
-            roles=roles, participants=str(participants), history=history
+            roles=roles, participants=str(participants), history=model_context_history
         )
+
         select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
         if ModelFamily.is_openai(self._model_client.model_info["family"]):
             select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
@@ -312,6 +351,7 @@ class SelectorGroupChatConfig(BaseModel):
     max_selector_attempts: int = 3
     emit_team_events: bool = False
     model_client_streaming: bool = False
+    model_context: ComponentModel | None = None
 
 
 class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@@ -349,6 +389,8 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
             Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
         emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
         model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
+        model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
+            :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages, which are cleared when the team is reset. Messages stored in the model context will be used for speaker selection.
 
     Raises:
         ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@@ -463,6 +505,64 @@ def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str |
                 await Console(team.run_stream(task="What is 1 + 1?"))
+
             asyncio.run(main())
+
+    A team with custom model context:
+
+    .. code-block:: python
+
+        import asyncio
+
+        from autogen_core.model_context import BufferedChatCompletionContext
+        from autogen_ext.models.openai import OpenAIChatCompletionClient
+
+        from autogen_agentchat.agents import AssistantAgent
+        from autogen_agentchat.conditions import TextMentionTermination
+        from autogen_agentchat.teams import SelectorGroupChat
+        from autogen_agentchat.ui import Console
+
+
+        async def main() -> None:
+            model_client = OpenAIChatCompletionClient(model="gpt-4o")
+            model_context = BufferedChatCompletionContext(buffer_size=5)
+
+            async def lookup_hotel(location: str) -> str:
+                return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
+
+            async def lookup_flight(origin: str, destination: str) -> str:
+                return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
+
+            async def book_trip() -> str:
+                return "Your trip is booked!"
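+
+            # With buffer_size=5, speaker selection sees only the five most
+            # recent messages of the conversation.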
+
+            travel_advisor = AssistantAgent(
+                "Travel_Advisor",
+                model_client,
+                tools=[book_trip],
+                description="Helps with travel planning.",
+            )
+            hotel_agent = AssistantAgent(
+                "Hotel_Agent",
+                model_client,
+                tools=[lookup_hotel],
+                description="Helps with hotel booking.",
+            )
+            flight_agent = AssistantAgent(
+                "Flight_Agent",
+                model_client,
+                tools=[lookup_flight],
+                description="Helps with flight booking.",
+            )
+            termination = TextMentionTermination("TERMINATE")
+            team = SelectorGroupChat(
+                [travel_advisor, hotel_agent, flight_agent],
+                model_client=model_client,
+                termination_condition=termination,
+                model_context=model_context,
+            )
+            await Console(team.run_stream(task="Book a 3-day trip to New York."))
+
+        asyncio.run(main())
     """
 
@@ -492,6 +592,7 @@ def __init__(
         custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
         emit_team_events: bool = False,
         model_client_streaming: bool = False,
+        model_context: ChatCompletionContext | None = None,
     ):
         super().__init__(
             participants,
@@ -513,6 +614,7 @@ def __init__(
         self._max_selector_attempts = max_selector_attempts
         self._candidate_func = candidate_func
         self._model_client_streaming = model_client_streaming
+        self._model_context = model_context
 
     def _create_group_chat_manager_factory(
         self,
@@ -545,6 +647,7 @@ def _create_group_chat_manager_factory(
             self._max_selector_attempts,
             self._candidate_func,
             self._emit_team_events,
+            self._model_context,
             self._model_client_streaming,
         )
 
@@ -560,6 +663,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
             # selector_func=self._selector_func.dump_component() if self._selector_func else None,
             emit_team_events=self._emit_team_events,
             model_client_streaming=self._model_client_streaming,
+            model_context=self._model_context.dump_component() if self._model_context else None,
         )
 
     @classmethod
@@ -579,4 +683,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
             # else None,
             emit_team_events=config.emit_team_events,
             model_client_streaming=config.model_client_streaming,
+            model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
         )
diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py
index 37763e866950..e7e8ba436c2b 100644
--- a/python/packages/autogen-agentchat/tests/test_group_chat.py
+++ b/python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -2,10 +2,28 @@
 import json
 import logging
 import tempfile
-from typing import Any, AsyncGenerator, List, Mapping, Sequence
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Sequence
 
 import pytest
 import pytest_asyncio
+from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
+from autogen_core.model_context import BufferedChatCompletionContext
+from autogen_core.models import (
+    AssistantMessage,
+    CreateResult,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    RequestUsage,
+    UserMessage,
+)
+from autogen_core.tools import FunctionTool
+from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
+from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.replay import ReplayChatCompletionClient
+from pydantic import BaseModel
+from utils import FileLogHandler
+
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import (
     AssistantAgent,
@@ -39,22 +57,6 @@ from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
-from autogen_core.models import (
-    AssistantMessage,
-    CreateResult,
-    FunctionExecutionResult,
-    FunctionExecutionResultMessage,
-    LLMMessage,
-    RequestUsage,
-    UserMessage,
-)
-from autogen_core.tools import FunctionTool
-from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
-from autogen_ext.models.openai import OpenAIChatCompletionClient
-from autogen_ext.models.replay import ReplayChatCompletionClient
-from pydantic import BaseModel
-from utils import FileLogHandler
 
 logger = logging.getLogger(EVENT_LOGGER_NAME)
 logger.setLevel(logging.DEBUG)
@@ -698,6 +700,60 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
     assert result2 == result
 
 
+@pytest.mark.asyncio
+async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | None) -> None:
+    buffered_context = BufferedChatCompletionContext(buffer_size=5)
+    await buffered_context.add_message(UserMessage(content="[User] Prefilled message", source="user"))
+
+    selector_group_chat_model_client = ReplayChatCompletionClient(
+        ["agent2", "agent1", "agent1", "agent2", "agent1", "agent2", "agent1"]
+    )
+    agent_one_model_client = ReplayChatCompletionClient(
+        ["[Agent One] First generation", "[Agent One] Second generation", "[Agent One] Third generation", "TERMINATE"]
+    )
+    agent_two_model_client = ReplayChatCompletionClient(
+        ["[Agent Two] First generation", "[Agent Two] Second generation", "[Agent Two] Third generation"]
+    )
+
+    agent1 = AssistantAgent("agent1", model_client=agent_one_model_client, description="Assistant agent 1")
+    agent2 = AssistantAgent("agent2", model_client=agent_two_model_client, description="Assistant agent 2")
+
+    termination = TextMentionTermination("TERMINATE")
+    team = SelectorGroupChat(
+        participants=[agent1, agent2],
+        model_client=selector_group_chat_model_client,
+        termination_condition=termination,
+        runtime=runtime,
+        emit_team_events=True,
+        allow_repeated_speaker=True,
+        model_context=buffered_context,
+    )
+    await team.run(
+        task="[GroupChat] Task",
+    )
+
+    messages_to_check = [
+        "user: [User] Prefilled message",
+        "user: [GroupChat] Task",
+        "agent2: [Agent Two] First generation",
+        "agent1: [Agent One] First generation",
+        "agent1: [Agent One] Second generation",
+        "agent2: [Agent Two] Second generation",
+        "agent1: [Agent One] Third generation",
+        "agent2: [Agent Two] Third generation",
+    ]
+
+    create_calls: List[Dict[str, Any]] = selector_group_chat_model_client.create_calls
+    for idx, call in enumerate(create_calls):
+        messages = call["messages"]
+        prompt = messages[0].content
+        prompt_lines = prompt.split("\n")
+        chat_history = messages_to_check[max(0, idx - 3) : idx + 2]
+        assert all(
+            line.strip() in prompt_lines for line in chat_history
+        ), f"Expected all lines {chat_history} to be in prompt, but got {prompt_lines}"
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None) -> None:
     model_client = ReplayChatCompletionClient(