Skip to content

Commit 82c50ce

Browse files
committed
FIX: fix errors and pyright/mypy issues
1 parent d8b0ac0 commit 82c50ce

File tree

6 files changed

+54
-69
lines changed

6 files changed

+54
-69
lines changed

python/packages/autogen-ext/src/autogen_ext/models/openai/_message_transform.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,25 @@
3030

3131
from ._utils import assert_valid_name, func_call_to_oai
3232

33+
EMPTY: Dict[str, Any] = {}
34+
3335

3436
# ===Mini Transformers===
35-
def _assert_valid_name(message: LLMMessage, context: Dict[str, Any]):
37+
def _assert_valid_name(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
38+
assert isinstance(message, (UserMessage, AssistantMessage))
3639
assert_valid_name(message.source)
37-
result: Dict[str, Any] = {}
38-
return result
40+
return EMPTY
3941

4042

41-
def _set_role(role: str):
43+
def _set_role(role: str) -> Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]:
4244
def inner(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
4345
return {"role": role}
4446

4547
return inner
4648

4749

4850
def _set_name(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
51+
assert isinstance(message, (UserMessage, AssistantMessage))
4952
assert_valid_name(message.source)
5053
return {"name": message.source}
5154

@@ -54,13 +57,16 @@ def _set_content_direct(message: LLMMessage, context: Dict[str, Any]) -> Dict[st
5457
return {"content": message.content}
5558

5659

57-
def _set_prepend_text_content(message: UserMessage, context: Dict[str, Any]) -> Dict[str, Any]:
60+
def _set_prepend_text_content(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
61+
assert isinstance(message, (UserMessage, AssistantMessage))
62+
assert isinstance(message.content, str)
5863
prepend = context.get("prepend_name", False)
5964
prefix = f"{message.source} said:\n" if prepend else ""
6065
return {"content": prefix + message.content}
6166

6267

63-
def _set_multimodal_content(message: UserMessage, context: Dict[str, Any]) -> Dict[str, Any]:
68+
def _set_multimodal_content(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
69+
assert isinstance(message, (UserMessage, AssistantMessage))
6470
prepend = context.get("prepend_name", False)
6571
parts: List[ChatCompletionContentPartParam] = []
6672

@@ -79,13 +85,16 @@ def _set_multimodal_content(message: UserMessage, context: Dict[str, Any]) -> Di
7985
return {"content": parts}
8086

8187

82-
def _set_tool_calls(message: AssistantMessage, context: Dict[str, Any]) -> Dict[str, Any]:
88+
def _set_tool_calls(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
89+
assert isinstance(message.content, list)
90+
assert isinstance(message, AssistantMessage)
8391
return {
8492
"tool_calls": [func_call_to_oai(x) for x in message.content],
8593
}
8694

8795

88-
def _set_thought_as_content(message: AssistantMessage, context: Dict[str, Any]) -> Dict[str, Any]:
96+
def _set_thought_as_content(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
97+
assert isinstance(message, AssistantMessage)
8998
return {"content": message.thought}
9099

91100

@@ -187,6 +196,7 @@ def user_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
187196

188197

189198
def assistant_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
199+
assert isinstance(message, AssistantMessage)
190200
if isinstance(message.content, list):
191201
if message.thought is not None:
192202
return "thought"
@@ -208,13 +218,19 @@ def assistant_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
208218
}
209219

210220

211-
def function_execution_result_message(message: LLMMessage, context: Dict[str, Any]) -> List[Dict[str, Any]]:
221+
def function_execution_result_message(
222+
message: LLMMessage, context: Dict[str, Any]
223+
) -> list[ChatCompletionToolMessageParam]:
224+
assert isinstance(message, FunctionExecutionResultMessage)
212225
return [
213226
ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
214227
]
215228

216229

217-
def function_execution_result_message_gemini(message: LLMMessage, context: Dict[str, Any]) -> List[Dict[str, Any]]:
230+
def function_execution_result_message_gemini(
231+
message: LLMMessage, context: Dict[str, Any]
232+
) -> list[ChatCompletionToolMessageParam]:
233+
assert isinstance(message, FunctionExecutionResultMessage)
218234
return [
219235
ChatCompletionToolMessageParam(content=x.content if x.content else " ", role="tool", tool_call_id=x.call_id)
220236
for x in message.content
@@ -282,4 +298,4 @@ def function_execution_result_message_gemini(message: LLMMessage, context: Dict[
282298
for model in __unknown_models:
283299
register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
284300

285-
register_transformer("openai", "default", __BASE_TRANSFORMER_MAP)
301+
# register_transformer("openai", "default", __BASE_TRANSFORMER_MAP)

python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py

+7-40
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import (
1212
Any,
1313
AsyncGenerator,
14+
Callable,
1415
Dict,
1516
List,
1617
Mapping,
@@ -58,7 +59,6 @@
5859
ChatCompletionContentPartParam,
5960
ChatCompletionContentPartTextParam,
6061
ChatCompletionMessageParam,
61-
ChatCompletionMessageToolCallParam,
6262
ChatCompletionRole,
6363
ChatCompletionSystemMessageParam,
6464
ChatCompletionToolMessageParam,
@@ -195,17 +195,6 @@ def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessage
195195
)
196196

197197

198-
def _old_func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
199-
return ChatCompletionMessageToolCallParam(
200-
id=message.id,
201-
function={
202-
"arguments": message.arguments,
203-
"name": message.name,
204-
},
205-
type="function",
206-
)
207-
208-
209198
def tool_message_to_oai(
210199
message: FunctionExecutionResultMessage,
211200
) -> Sequence[ChatCompletionToolMessageParam]:
@@ -248,25 +237,16 @@ def to_oai_type(
248237
}
249238
transformers = get_transformer("openai", model_family)
250239

251-
def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
240+
def raise_value_error(message: LLMMessage, context: Dict[str, Any]) -> ChatCompletionMessageParam:
252241
raise ValueError(f"Unknown message type: {type(message)}")
253242

254-
transformer = transformers.get(type(message), raise_value_error)
243+
transformer: Callable[
244+
[LLMMessage, Dict[str, Any]], Union[ChatCompletionMessageParam, Sequence[ChatCompletionMessageParam]]
245+
] = transformers.get(type(message), raise_value_error)
255246
result = transformer(message, context)
256247
if isinstance(result, list):
257-
return result
258-
return [result]
259-
260-
261-
def _old_to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]:
262-
if isinstance(message, SystemMessage):
263-
return [system_message_to_oai(message)]
264-
elif isinstance(message, UserMessage):
265-
return [user_message_to_oai(message, prepend_name)]
266-
elif isinstance(message, AssistantMessage):
267-
return [assistant_message_to_oai(message)]
268-
else:
269-
return tool_message_to_oai(message)
248+
return cast(List[ChatCompletionMessageParam], result)
249+
return cast(List[ChatCompletionMessageParam], [result])
270250

271251

272252
def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
@@ -364,19 +344,6 @@ def normalize_name(name: str) -> str:
364344
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
365345

366346

367-
def _old_assert_valid_name(name: str) -> str:
368-
"""
369-
Ensure that configured names are valid, raises ValueError if not.
370-
371-
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
372-
"""
373-
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
374-
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
375-
if len(name) > 64:
376-
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
377-
return name
378-
379-
380347
@dataclass
381348
class CreateParams:
382349
messages: List[ChatCompletionMessageParam]

python/packages/autogen-ext/src/autogen_ext/models/openai/_utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import re
2-
from typing import Any, Dict, List, Union
32

43
from autogen_core import FunctionCall
54
from openai.types.chat import ChatCompletionMessageToolCallParam

python/packages/autogen-ext/src/autogen_ext/transformation/__init__.py

-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
register_transformer,
77
)
88
from .types import (
9-
BuilderFunc,
10-
BuilderMap,
119
TransformerFunc,
1210
TransformerMap,
1311
)
@@ -20,6 +18,4 @@
2018
"MESSAGE_TRANSFORMERS",
2119
"TransformerMap",
2220
"TransformerFunc",
23-
"BuilderMap",
24-
"BuilderFunc",
2521
]

python/packages/autogen-ext/src/autogen_ext/transformation/registry.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
from collections import defaultdict
2-
from typing import Any, Callable, Dict, List, Type
2+
from typing import Any, Callable, Dict, List
33

4-
from autogen_core.models import LLMMessage, ModelFamily
4+
from autogen_core.models import LLMMessage
55

66
from .types import (
7-
BuilderMap,
87
TransformerFunc,
98
TransformerMap,
109
)
1110

1211
# Global registry of model family → message transformer map
1312
# Each model family (e.g. "gpt-4o", "gemini-1.5-flash") maps to a dict of LLMMessage type → transformer function
1413
MESSAGE_TRANSFORMERS: Dict[str, Dict[str, TransformerMap]] = defaultdict(dict)
15-
MESSAGE_BUILDERS: Dict[str, BuilderMap] = {}
1614

1715

1816
def build_transformer_func(
@@ -69,7 +67,7 @@ def transformer(message: LLMMessage, context: Dict[str, Any]) -> Any:
6967
return transformer
7068

7169

72-
def register_transformer(api: str, model_family: str, transformer_map: TransformerMap):
70+
def register_transformer(api: str, model_family: str, transformer_map: TransformerMap) -> None:
7371
"""
7472
Registers a transformer map for a given model family.
7573
@@ -82,6 +80,17 @@ def register_transformer(api: str, model_family: str, transformer_map: Transform
8280
MESSAGE_TRANSFORMERS[api][model_family] = transformer_map
8381

8482

83+
def _find_model_family(api: str, model: str) -> str:
84+
"""
85+
Finds the best matching model family for the given model.
86+
Search via prefix matching (e.g. "gpt-4o-1.0" → "gpt-4o").
87+
"""
88+
for family in MESSAGE_TRANSFORMERS[api].keys():
89+
if model.startswith(family):
90+
return family
91+
return "default"
92+
93+
8594
def get_transformer(api: str, model_family: str) -> TransformerMap:
8695
"""
8796
Returns the registered transformer map for the given model family.
@@ -95,9 +104,10 @@ def get_transformer(api: str, model_family: str) -> TransformerMap:
95104
96105
Keeping this as a function (instead of direct dict access) improves long-term flexibility.
97106
"""
98-
transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})
99-
if not transformer:
100-
transformer = MESSAGE_TRANSFORMERS.get("default", {}).get("default", {})
107+
108+
model = _find_model_family(api, model_family)
109+
110+
transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model, {})
101111

102112
if not transformer:
103113
raise ValueError(f"No transformer found for model family '{model_family}'")
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
from typing import Any, Callable, Dict, List, Type
1+
from typing import Any, Callable, Dict, Type
22

3-
from autogen_core.models import LLMMessage, ModelFamily
3+
from autogen_core.models import LLMMessage
44

55
TransformerFunc = Callable[[LLMMessage, Dict[str, Any]], Any]
66
TransformerMap = Dict[Type[LLMMessage], TransformerFunc]
7-
8-
BuilderFunc = Callable[[List[Any], Dict[str, Any]], Any]
9-
BuilderMap = Dict[Type[LLMMessage], BuilderFunc]

0 commit comments

Comments
 (0)