
Commit 5a7fb3a

[Model] Add ToolParser and MoE Config for Hunyuan A13B (#20820)
Signed-off-by: Asher Zhang <[email protected]>
1 parent 11dfdf2 commit 5a7fb3a

17 files changed: +1712 -4 lines

benchmarks/kernels/benchmark_moe.py

Lines changed: 5 additions & 0 deletions
@@ -586,6 +586,11 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM", ):
+        E = config.num_experts
+        topk = config.moe_topk[0]
+        intermediate_size = config.moe_intermediate_size[0]
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Support for llama4
         config = config.get_text_config()
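Hunyuan's HF config stores its MoE hyperparameters as per-layer lists, which is why the new branch indexes `moe_topk[0]` and `moe_intermediate_size[0]`. A minimal sketch of reading those fields, assuming the checkpoint exposes them exactly as the diff implies:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("tencent/Hunyuan-A13B-Instruct",
                                    trust_remote_code=True)
# Per-layer lists: take the first entry, as benchmark_moe.py does above.
E = config.num_experts
topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0]
print(E, topk, intermediate_size)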

docs/features/reasoning_outputs.md

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ vLLM currently supports the following reasoning models:
 | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` ||
 | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` |||
 | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` ||
+| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` ||
 
 !!! note
     IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
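Once a server is running with `--reasoning-parser hunyuan_a13b`, the parsed thinking is surfaced on the response's `reasoning_content` field. A minimal client sketch (the prompt and server address are illustrative):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="tencent/Hunyuan-A13B-Instruct",
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
)
# vLLM returns the parsed reasoning separately from the final answer.
print(resp.choices[0].message.reasoning_content)
print(resp.choices[0].message.content)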

docs/features/tool_calling.md

Lines changed: 10 additions & 0 deletions
@@ -288,6 +288,16 @@ Supported models:
 
 Flags: `--tool-call-parser kimi_k2`
 
+### Hunyuan Models (`hunyuan_a13b`)
+
+Supported models:
+
+* `tencent/Hunyuan-A13B-Instruct` (the chat template is already included in the Hugging Face model files.)
+
+Flags:
+* For non-reasoning: `--tool-call-parser hunyuan_a13b`
+* For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable_reasoning`
+
 ### Models with Pythonic Tool Calls (`pythonic`)
 
 A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
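With the flags above, tool calls flow through the standard OpenAI-compatible API. A usage sketch (the `get_weather` tool is illustrative, not part of the commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
resp = client.chat.completions.create(
    model="tencent/Hunyuan-A13B-Instruct",
    messages=[{"role": "user", "content": "Weather in Boston?"}],
    tools=tools,
)
# The hunyuan_a13b parser turns <tool_calls>[...]</tool_calls> into this list.
print(resp.choices[0].message.tool_calls)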
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+{% set loop_messages = messages %}
+{% if tools %}
+{% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %}
+{% set weekday_cn = weekday_map[strftime_now('%A')] %}
+{% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %}
+{% set datetime_str = datetime_str + ' ' + weekday_cn %}
+{% for message in loop_messages %}
+{% if 'content' in message %}
+{% set content = message['content'] %}
+{% else %}
+{% set content = '' %}
+{% endif %}
+{% if loop.index0 == 0 %}
+{% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。
+如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。
+如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。
+如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。
+你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。
+' %}
+{% set content_tmp = content_tmp + '
+' + tools | tojson + '
+' %}
+{% if message['role'] == 'system' %}
+{% set content_tmp = content_tmp + '
+额外要求:
+' + content + '
+
+如果你决定返回函数调用,请将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>,不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。
+否则,请参考开头说的三种情况,以助手:开头进行回复。
+
+如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %}
+{% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %}
+{% elif message['role'] == 'user' %}
+{% set content_tmp = content_tmp + '
+如果你决定返回函数调用,请将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>,不得包含其他文本。
+否则,请参考开头说的三种情况,以助手:开头进行回复。
+
+当前时间:' + datetime_str %}
+{% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%}
+{% set content = content_tmp + '用户:' + content + '<|extra_0|>' %}
+{% endif %}
+{% else %}
+{% if message['role'] == 'user' %}
+{% set content = '用户:' + content + '<|extra_0|>' %}
+{% elif message['role'] == 'assistant' %}
+{% if 'tool_calls' in message %}
+{% set tool_calls = message['tool_calls'] %}
+{% set ns = namespace(tool_calls="[") %}
+{% for tool_call in tool_calls %}
+{% set function = tool_call['function'] %}
+{% set name = function['name'] %}
+{% set ns.tool_calls = ns.tool_calls + '{"name": "' + name + '", '%}
+{% set arguments = function['arguments'] %}
+{% if arguments is not string %}
+{% set arguments = arguments | tojson %}
+{% endif %}
+{% set ns.tool_calls = ns.tool_calls + '"arguments": ' + arguments + '}' %}
+{% if not loop.last %}
+{% set ns.tool_calls = ns.tool_calls + ', '%}
+{% endif %}
+{% endfor %}
+{% set ns.tool_calls = ns.tool_calls + ']' %}
+{% set content = content + '<tool_calls>' + ns.tool_calls + '</tool_calls>' %}
+{% else %}
+{% set content = '助手:' + content %}
+{% endif %}
+{% set content = content + '<|eos|>' %}
+{% elif message['role'] == 'tool' %}
+{% if content is not string %}
+{% set content = content | tojson %}
+{% endif %}
+{% set content = '<tool_response>' + content + '</tool_response>' %}
+{% set content = content + '<|extra_0|>' %}
+{% endif %}
+{% endif %}
+{{- content -}}
+{% endfor %}
+{% else %}
+{% set context = {'has_head': true} %}
+{% for message in loop_messages %}
+{% if 'content' in message %}
+{% set content = message['content'] %}
+{% else %}
+{% set content = '' %}
+{% endif %}
+{% if loop.index0 == 0 %}
+{% if content == '' %}
+{% set _ = context.update({'has_head': false}) %}
+{% elif message['role'] == 'system' %}
+{% set content = '<|startoftext|>' + content + '<|extra_4|>' %}
+{% endif %}
+{% endif %}
+{% if message['role'] == 'user' %}
+{% if loop.index0 == 1 and not context.has_head %}
+{% set content = '<|startoftext|>' + content %}
+{% endif %}
+{% if loop.index0 == 1 and context.has_head %}
+{% set content = content + '<|extra_0|>' %}
+{% else %}
+{% set content = '<|startoftext|>' + content + '<|extra_0|>' %}
+{% endif %}
+{% elif message['role'] == 'assistant' %}
+{% set content = content + '<|eos|>' %}
+{% elif message['role'] == 'tool' %}
+{% set content = content + '<|extra_0|>' %}
+{% endif %}
+{{- content -}}
+{% endfor %}
+{% endif %}
+{%- if enable_thinking is defined and enable_thinking is false %}
+{{- '<think>\n\n</think>\n' }}
+{%- endif %}
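The template can be exercised directly through the tokenizer. A sketch, assuming a transformers version whose chat-template environment provides `strftime_now` (which the template calls) and reusing the illustrative `get_weather` tool from above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct",
                                          trust_remote_code=True)
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # illustrative tool, not part of the commit
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}]
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=tools,
    tokenize=False,
    add_generation_prompt=True,
)
# The rendered prompt carries the <|startoftext|>...<|extra_4|>/<|extra_0|>
# framing defined by the template above.
print(prompt)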
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+
+import json
+from unittest.mock import MagicMock
+
+import pytest
+
+from tests.entrypoints.openai.tool_parsers.utils import (
+    run_tool_extraction, run_tool_extraction_streaming)
+from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+
+
+def make_tool_call(name, arguments):
+    return ToolCall(type="function",
+                    function=FunctionCall(name=name,
+                                          arguments=json.dumps(arguments)))
+
+
+# TODO: add reason prefix and suffix.
+
+
+@pytest.mark.parametrize(
+    "model_output,expected_tool_calls,expected_content",
+    [
+        # No tool call
+        ("How can I help you today?", [], "How can I help you today?"),
+        # Single tool call, no content
+        (
+            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>",  #noqa: E501
+            [
+                make_tool_call("get_weather", {
+                    "city": "San Francisco",
+                    "metric": "celsius"
+                })
+            ],
+            None),
+        # Multiple tool calls
+        (
+            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>",  #noqa: E501
+            [
+                make_tool_call("get_weather", {
+                    "city": "San Francisco",
+                    "metric": "celsius"
+                }),
+                make_tool_call(
+                    "register_user", {
+                        "name": "John Doe",
+                        "age": 37,
+                        "address": {
+                            "city": "San Francisco",
+                            "state": "CA"
+                        },
+                        "role": None,
+                        "passed_test": True,
+                        "aliases": ["John", "Johnny"]
+                    })
+            ],
+            None),
+        # Content before tool call
+        (
+            "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>",  #noqa: E501
+            [make_tool_call("get_weather", {"city": "Boston"})],
+            "I will call the tool now. "),
+        # Content after tool call (should be stripped)
+        (
+            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!",  #noqa: E501
+            [make_tool_call("get_weather", {"city": "Seattle"})],
+            None),
+        (
+            "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
+            [
+                make_tool_call(
+                    "complex_tool",
+                    {"level1": {
+                        "level2": {
+                            "level3": {
+                                "value": 123
+                            }
+                        }
+                    }})
+            ],
+            None,
+        ),
+    ])
+def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
+                                          expected_content):
+    mock_tokenizer = MagicMock()
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
+        "hunyuan_a13b")(mock_tokenizer)
+    content, tool_calls = run_tool_extraction(tool_parser,
+                                              model_output,
+                                              streaming=False)
+
+    # Align the randomly generated ids before comparing.
+    for idx in range(len(tool_calls)):
+        tool_calls[idx].id = expected_tool_calls[idx].id
+    assert tool_calls == expected_tool_calls
+    assert content == expected_content
+
+
+# Streaming test: simulate incremental output
+@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
+    ([
+        "<tool_calls>[{\"name\": \"get_weather\", ",
+        "\"arguments\": {\"city\": \"San Francisco\", ",
+        "\"metric\": \"celsius\"}}]", "</tool_calls>"
+    ], [
+        make_tool_call("get_weather", {
+            "city": "San Francisco",
+            "metric": "celsius"
+        })
+    ]),
+    ([
+        "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
+        " {\"city\": \"Boston\"}", "}]", "</tool_calls>"
+    ], [make_tool_call("get_weather", {"city": "Boston"})]),
+    ([
+        "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
+        " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
+    ], [make_tool_call("get_weather", {"city": "Boston"})]),
+    pytest.param([
+        "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
+        " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
+        "]</tool_calls>"
+    ], [
+        make_tool_call("complex_tool",
+                       {"level1": {
+                           "level2": {
+                               "level3": {
+                                   "value": 123
+                               }
+                           }
+                       }})
+    ],
+                 marks=pytest.mark.xfail(
+                     reason="stream parsing does not support nested JSON yet.")),
+])
+def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
+    mock_tokenizer = MagicMock()
+
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
+        "hunyuan_a13b")(mock_tokenizer)
+    reconstructor = run_tool_extraction_streaming(
+        tool_parser, model_deltas, assert_one_tool_per_delta=False)
+
+    # Align the randomly generated ids before comparing.
+    for idx in range(len(reconstructor.tool_calls)):
+        reconstructor.tool_calls[idx].id = expected_tool_calls[idx].id
+
+    assert reconstructor.tool_calls == expected_tool_calls
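For orientation, the non-streaming behavior these tests pin down (take the JSON array between `<tool_calls>` tags, keep text before the tags as content, and drop text after) can be approximated in a few lines. This is a sketch of the expected contract, not the parser's actual implementation:

import json
import re


def extract_tool_calls_sketch(model_output: str):
    # Hypothetical helper mirroring the expectations encoded in the tests.
    match = re.search(r"<tool_calls>(.*?)</tool_calls>", model_output,
                      re.DOTALL)
    if match is None:
        return model_output, []
    # Leading text becomes content; empty content is reported as None.
    content = model_output[:match.start()] or None
    return content, json.loads(match.group(1))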

tests/reasoning/test_hunyuan_reasoning_parser.py

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,12 @@
     "reasoning_content": "This is a reasoning section",
     "content": None,
 }
+
+COMPLETE_REASONING_WITH_SYMBOL = {
+    "output": f"{START_REASONING}This is a reasoning section!{START_RESPONSE}",
+    "reasoning_content": "This is a reasoning section!",
+    "content": None,
+}
 NO_REASONING = {
     "output": "This is content",
     "reasoning_content": None,
@@ -70,6 +76,11 @@
         COMPLETE_REASONING,
         id="complete_reasoning",
     ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING_WITH_SYMBOL,
+        id="complete_reasoning_with_symbol",
+    ),
     pytest.param(
         False,
         NO_REASONING,

vllm/entrypoints/openai/serving_chat.py

Lines changed: 16 additions & 3 deletions
@@ -613,8 +613,13 @@ async def chat_completion_stream_generator(
                 previous_text = previous_texts[i]
                 previous_token_ids = all_previous_token_ids[i]
                 current_text = previous_text + delta_text
-                current_token_ids = previous_token_ids + list(
-                    output.token_ids)
+
+                # Avoid a "None + list" error when there are no previous ids.
+                if previous_token_ids:
+                    current_token_ids = previous_token_ids + list(
+                        output.token_ids)
+                else:
+                    current_token_ids = list(output.token_ids)
 
                 # handle streaming deltas for tools with named tool_choice
                 if tool_choice_function_name:
@@ -1077,9 +1082,17 @@ async def chat_completion_full_generator(
             else:
                 # FOR NOW make it a chat message; we will have to detect
                 # the type to make it later.
+                ret_content = content
+
+                # Prefer the content returned by the tool parser, since the
+                # parser may have modified it.
+                if (tool_call_info.content
+                        and len(tool_call_info.content) > 0):
+                    ret_content = tool_call_info.content
+
                 message = ChatMessage(role=role,
                                       reasoning_content=reasoning_content,
-                                      content=content)
+                                      content=ret_content)
 
         # undetermined case that is still important to handle
         else:

vllm/entrypoints/openai/tool_parsers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 from .granite_20b_fc_tool_parser import Granite20bFCToolParser
 from .granite_tool_parser import GraniteToolParser
 from .hermes_tool_parser import Hermes2ProToolParser
+from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser
 from .internlm2_tool_parser import Internlm2ToolParser
 from .jamba_tool_parser import JambaToolParser
 from .kimi_k2_tool_parser import KimiK2ToolParser
@@ -23,5 +24,5 @@
     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
     "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
     "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
-    "KimiK2ToolParser"
+    "KimiK2ToolParser", "HunyuanA13BToolParser"
 ]
