
Add model_context to SelectorGroupChat for enhanced speaker selection #6330


Merged
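Summary of the change (as reflected in the diff below): this PR adds an optional model_context (ChatCompletionContext) parameter to SelectorGroupChat, defaulting to an UnboundedChatCompletionContext. All appends to the group-chat message thread are routed through a new update_message_thread coroutine, which the selector manager overrides to mirror chat messages (including LLM messages carried by HandoffMessage.context) into the model context. Speaker selection then builds its history from the model context rather than the raw thread, so a bounded context such as BufferedChatCompletionContext limits how much history the selector sees; the context is cleared on reset and serialized with the team config.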
Changes from all commits
@@ -115,7 +115,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
)

# Append all messages to thread
self._message_thread.extend(message.messages)
await self.update_message_thread(message.messages)

# Check termination condition after processing all messages
if await self._apply_termination_condition(message.messages):
@@ -139,17 +139,19 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
cancellation_token=ctx.cancellation_token,
)

async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
self._message_thread.extend(messages)

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
try:
# Append the message to the message thread and construct the delta.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
self._message_thread.append(inner_message)
delta.append(inner_message)
self._message_thread.append(message.agent_response.chat_message)
delta.append(message.agent_response.chat_message)
await self.update_message_thread(delta)

# Check if the conversation should be terminated.
if await self._apply_termination_condition(delta, increment_turn_count=True):
(diff continues in a second file)
@@ -191,7 +191,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
self._message_thread.append(message.agent_response.chat_message)
await self.update_message_thread([message.agent_response.chat_message])
delta.append(message.agent_response.chat_message)

if self._termination_condition is not None:
@@ -263,7 +263,7 @@ async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> No
)

# Save my copy
self._message_thread.append(ledger_message)
await self.update_message_thread([ledger_message])

# Log it to the output topic.
await self.publish_message(
@@ -376,7 +376,7 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None

# Broadcast the next step
message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
self._message_thread.append(message) # My copy
await self.update_message_thread([message]) # My copy

await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
# Log it to the output topic.
@@ -458,7 +458,7 @@ async def _prepare_final_answer(self, reason: str, cancellation_token: Cancellat
assert isinstance(response.content, str)
message = TextMessage(content=response.content, source=self._name)

self._message_thread.append(message) # My copy
await self.update_message_thread([message]) # My copy

# Log it to the output topic.
await self.publish_message(
(diff continues in python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py)
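The two files above now route every append to self._message_thread through the new update_message_thread coroutine, giving subclasses a single hook to observe thread updates. A minimal sketch of that hook pattern (hypothetical class and field names, strings standing in for messages; not code from this PR):

import asyncio
from typing import List, Sequence


class BaseManagerSketch:
    """Base: the single place where the message thread grows."""

    def __init__(self) -> None:
        self._message_thread: List[str] = []

    async def update_message_thread(self, messages: Sequence[str]) -> None:
        self._message_thread.extend(messages)


class SelectorManagerSketch(BaseManagerSketch):
    """Subclass: override the hook to mirror messages into a second store."""

    def __init__(self) -> None:
        super().__init__()
        self._model_context: List[str] = []  # stand-in for a ChatCompletionContext

    async def update_message_thread(self, messages: Sequence[str]) -> None:
        await super().update_message_thread(messages)
        # In the real PR this adds each chat message to the model context.
        self._model_context.extend(messages)


async def demo() -> None:
    manager = SelectorManagerSketch()
    await manager.update_message_thread(["user: hello", "agent: hi"])
    # Both stores saw the same update through the single hook.
    assert manager._message_thread == manager._model_context


asyncio.run(demo())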
@@ -4,11 +4,16 @@
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelFamily,
SystemMessage,
UserMessage,
@@ -22,6 +27,7 @@
from ...messages import (
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
MessageFactory,
ModelClientStreamingChunkEvent,
SelectorEvent,
@@ -65,6 +71,7 @@
max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType],
emit_team_events: bool,
model_context: ChatCompletionContext | None,
model_client_streaming: bool = False,
) -> None:
super().__init__(
@@ -90,13 +97,19 @@
self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
self._model_client_streaming = model_client_streaming
if model_context is not None:
self._model_context = model_context
else:
self._model_context = UnboundedChatCompletionContext()
self._cancellation_token = CancellationToken()

async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass

async def reset(self) -> None:
self._current_turn = 0
self._message_thread.clear()
await self._model_context.clear()
if self._termination_condition is not None:
await self._termination_condition.reset()
self._previous_speaker = None
@@ -112,16 +125,37 @@
async def load_state(self, state: Mapping[str, Any]) -> None:
selector_state = SelectorManagerState.model_validate(state)
self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
await self._add_messages_to_context(
self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
)
self._current_turn = selector_state.current_turn
self._previous_speaker = selector_state.previous_speaker

@staticmethod
async def _add_messages_to_context(
model_context: ChatCompletionContext,
messages: Sequence[BaseChatMessage],
) -> None:
"""
Add incoming messages to the model context.
"""
for msg in messages:
if isinstance(msg, HandoffMessage):
for llm_msg in msg.context:
await model_context.add_message(llm_msg)

# Codecov (codecov/patch): added lines 144-145 were not covered by tests.
await model_context.add_message(msg.to_model_message())

async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
self._message_thread.extend(messages)
base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
await self._add_messages_to_context(self._model_context, base_chat_messages)

async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.

A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
"""

# Use the selector function if provided.
if self._selector_func is not None:
if self._is_selector_func_async:
@@ -163,18 +197,6 @@

assert len(participants) > 0

# Construct the history of the conversation.
history_messages: List[str] = []
for msg in thread:
if not isinstance(msg, BaseChatMessage):
# Only process chat messages.
continue
message = f"{msg.source}: {msg.to_model_text()}"
history_messages.append(
message.rstrip() + "\n\n"
) # Create some consistency for how messages are separated in the transcript
history = "\n".join(history_messages)

# Construct agent roles.
# Each agent should appear on a single line.
roles = ""
@@ -184,17 +206,34 @@

# Select the next speaker.
if len(participants) > 1:
agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
else:
agent_name = participants[0]
self._previous_speaker = agent_name
trace_logger.debug(f"Selected speaker: {agent_name}")
return agent_name

async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str:
def construct_message_history(self, message_history: List[LLMMessage]) -> str:
# Construct the history of the conversation.
history_messages: List[str] = []
for msg in message_history:
if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
message = f"{msg.source}: {msg.content}"
history_messages.append(
message.rstrip() + "\n\n"
) # Create some consistency for how messages are separated in the transcript

history: str = "\n".join(history_messages)
return history

async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
model_context_messages = await self._model_context.get_messages()
model_context_history = self.construct_message_history(model_context_messages)

select_speaker_prompt = self._selector_prompt.format(
roles=roles, participants=str(participants), history=history
roles=roles, participants=str(participants), history=model_context_history
)

select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
if ModelFamily.is_openai(self._model_client.model_info["family"]):
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
@@ -312,6 +351,7 @@
max_selector_attempts: int = 3
emit_team_events: bool = False
model_client_streaming: bool = False
model_context: ComponentModel | None = None


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@@ -349,6 +389,8 @@
Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
:class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in the model context will be used for speaker selection. The initial messages will be cleared when the team is reset.

Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@@ -463,6 +505,64 @@
await Console(team.run_stream(task="What is 1 + 1?"))


asyncio.run(main())

A team with custom model context:

.. code-block:: python

import asyncio

from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console


async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
model_context = BufferedChatCompletionContext(buffer_size=5)

async def lookup_hotel(location: str) -> str:
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."

async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."

async def book_trip() -> str:
return "Your trip is booked!"

travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
termination = TextMentionTermination("TERMINATE")
team = SelectorGroupChat(
[travel_advisor, hotel_agent, flight_agent],
model_client=model_client,
termination_condition=termination,
model_context=model_context,
)
await Console(team.run_stream(task="Book a 3-day trip to new york."))


asyncio.run(main())
"""

@@ -492,6 +592,7 @@
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
model_client_streaming: bool = False,
model_context: ChatCompletionContext | None = None,
):
super().__init__(
participants,
@@ -513,6 +614,7 @@
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._model_client_streaming = model_client_streaming
self._model_context = model_context

def _create_group_chat_manager_factory(
self,
@@ -545,6 +647,7 @@
self._max_selector_attempts,
self._candidate_func,
self._emit_team_events,
self._model_context,
self._model_client_streaming,
)

@@ -560,6 +663,7 @@
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
emit_team_events=self._emit_team_events,
model_client_streaming=self._model_client_streaming,
model_context=self._model_context.dump_component() if self._model_context else None,
)

@classmethod
@@ -579,4 +683,5 @@
# else None,
emit_team_events=config.emit_team_events,
model_client_streaming=config.model_client_streaming,
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
)
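Beyond the buffered-context example in the docstring, the new parameter also supports the preloading mentioned there: a context created with initial messages informs the very first speaker selection. A minimal sketch, not part of the diff, assuming the initial_messages keyword on UnboundedChatCompletionContext and a configured OpenAI API key:

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import SelectorGroupChat
from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")

    # Preload the context; these messages inform speaker selection and are
    # cleared when the team is reset.
    model_context = UnboundedChatCompletionContext(
        initial_messages=[
            UserMessage(content="The user prefers budget hotels.", source="user"),
        ]
    )

    hotel_agent = AssistantAgent(
        "Hotel_Agent", model_client, description="Helps with hotel booking."
    )
    flight_agent = AssistantAgent(
        "Flight_Agent", model_client, description="Helps with flight booking."
    )

    team = SelectorGroupChat(
        [hotel_agent, flight_agent],
        model_client=model_client,
        termination_condition=MaxMessageTermination(6),
        model_context=model_context,
    )
    result = await team.run(task="Plan a weekend trip to Paris.")
    print(result.stop_reason)


asyncio.run(main())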