From f64adb487b39aa140183e38f5b3b7afcff2130dc Mon Sep 17 00:00:00 2001
From: Asher Zhang
Date: Thu, 10 Jul 2025 17:11:33 +0800
Subject: [PATCH 1/5] [Model] Add ToolParser for Hunyuan A13B.

- add streaming and non-streaming support
- reasoning parser: use the regex package.
- reasoning parser: add missing function.

Signed-off-by: Asher Zhang
---
 vllm/entrypoints/openai/serving_chat.py       |  19 +-
 .../openai/tool_parsers/__init__.py           |   3 +-
 .../tool_parsers/hunyuan_a13b_tool_parser.py  | 448 ++++++++++++++++++
 .../hunyuan_a13b_reasoning_parser.py          |   7 +
 4 files changed, 473 insertions(+), 4 deletions(-)
 create mode 100644 vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py

diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index b902166a25b3..a5eb16a53976 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -613,8 +613,13 @@ async def chat_completion_stream_generator(
                 previous_text = previous_texts[i]
                 previous_token_ids = all_previous_token_ids[i]
                 current_text = previous_text + delta_text
-                current_token_ids = previous_token_ids + list(
-                    output.token_ids)
+
+                # avoid a "None + list" TypeError.
+                if previous_token_ids:
+                    current_token_ids = previous_token_ids + list(
+                        output.token_ids)
+                else:
+                    current_token_ids = list(output.token_ids)
 
                 # handle streaming deltas for tools with named tool_choice
                 if tool_choice_function_name:
@@ -1077,9 +1082,17 @@ async def chat_completion_full_generator(
             else:
                 # FOR NOW make it a chat message; we will have to detect
                 # the type to make it later.
+                ret_content = content
+
+                # prefer the content returned by the tool parser, since
+                # the tool parser may have modified it.
+                if (tool_call_info.content
+                        and len(tool_call_info.content) > 0):
+                    ret_content = tool_call_info.content
+
                 message = ChatMessage(role=role,
                                       reasoning_content=reasoning_content,
-                                      content=content)
+                                      content=ret_content)
 
             # undetermined case that is still important to handle
             else:
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 218a120a5bb0..137375b9707c 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -6,6 +6,7 @@
 from .granite_20b_fc_tool_parser import Granite20bFCToolParser
 from .granite_tool_parser import GraniteToolParser
 from .hermes_tool_parser import Hermes2ProToolParser
+from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser
 from .internlm2_tool_parser import Internlm2ToolParser
 from .jamba_tool_parser import JambaToolParser
 from .kimi_k2_tool_parser import KimiK2ToolParser
@@ -23,5 +24,5 @@
     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
     "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
     "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
-    "KimiK2ToolParser"
+    "KimiK2ToolParser", "HunyuanA13BToolParser"
 ]
diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
new file mode 100644
index 000000000000..060baaec1f7b
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
@@ -0,0 +1,448 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501, SIM102
+
+import json
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import consume_space
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
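+# NOTE: expected model output layout (inferred from the regex patterns
+# below): optional <think>...</think> reasoning, then an <answer> block in
+# which tool calls are wrapped as
+# <tool_calls>[{"name": ..., "arguments": {...}}]</tool_calls>.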
+@ToolParserManager.register_module("hunyuan_a13b")
+class HunyuanA13BToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        logger.info("Hunyuan A13B Tool Parser inited.")
+        # Initialize state for streaming mode
+        self.prev_tool_calls: list[dict] = []
+        self.current_tool_id = -1
+        self.current_tool_name_sent = False
+        self.streamed_args: list[str] = [
+        ]  # Track arguments sent for each tool
+
+        # For backward compatibility with tests
+        self.current_tools_sent: list[bool] = []
+
+        # For backward compatibility with serving code
+        self.prev_tool_call_arr = []
+
+        # Regex patterns for preprocessing
+        self.answer_tool_calls_pattern = re.compile(
+            r"<tool_calls>([\s\S]*?)</tool_calls>")
+
+        self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')
+
+        self.tool_empty_arg_reg = re.compile(
+            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
+
+        self.tool_non_empty_arg_reg = re.compile(
+            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
+        )
+
+        self.bot_string = "<tool_calls>"
+
+        # Define streaming state type to be initialized later
+        self.streaming_state: dict[str, Any] = {
+            "current_tool_index": -1,
+            "tool_ids": [],
+            "sent_tools": [],
+        }
+
+    def preprocess_model_output(
+            self, model_output: str) -> tuple[Optional[str], Optional[str]]:
+        """
+        Preprocess the model output to extract content and potential tool calls.
+        Returns:
+            Tuple of (content, potential_tool_calls_json)
+        """
+        # Check for <answer> with <tool_calls></tool_calls> pattern
+
+        answer_tool_calls_match = self.answer_tool_calls_pattern.search(
+            model_output)
+        if answer_tool_calls_match:
+            tool_calls_content = answer_tool_calls_match.group(1).strip()
+
+            content = re.sub(r'<tool_calls>.*?</tool_calls>',
+                             '',
+                             model_output,
+                             flags=re.DOTALL)
+            content = content[:content.rfind("</answer>") + len("</answer>")] \
+                if "</answer>" in content else content
+
+            try:
+                json.loads(tool_calls_content)
+                return content, tool_calls_content
+            except json.JSONDecodeError:
+                # If can't parse directly, continue with other patterns
+                pass
+        return model_output, None
+
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        """
+        Extract tool calls from a complete model output.
+        """
+        try:
+            # Preprocess the model output
+            content, potential_tool_calls = self.preprocess_model_output(
+                model_output)
+
+            if not potential_tool_calls:
+                # some text should be filtered out when there is no
+                # function call; that text comes from a13b's chat template.
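+                # "助手:" ("Assistant:") is the prefix the chat template
+                # prepends to plain replies; only the first occurrence
+                # is removed.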
+ cleaned_output = content.replace("助手:", "", 1) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=cleaned_output) + + # Parse the potential tool calls as JSON + tool_calls_data = json.loads(potential_tool_calls) + + # Ensure it's an array + if not isinstance(tool_calls_data, list): + logger.debug("Tool calls data is not an array") + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=content or model_output, + ) + + tool_calls: list[ToolCall] = [] + + for idx, call in enumerate(tool_calls_data): + if (not isinstance(call, dict) or "name" not in call + or "arguments" not in call): + logger.debug("Invalid tool call format at index %d", idx) + continue + + tool_call = ToolCall( + id=f"call_{random_uuid()}", + type="function", + function=FunctionCall( + name=call["name"], + arguments=(json.dumps(call["arguments"]) if isinstance( + call["arguments"], dict) else call["arguments"]), + ), + ) + tool_calls.append(tool_call) + + logger.debug("exit in last position.") + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content if len(content.strip()) > 0 else None, + ) + + except Exception as e: + logger.exception("Error extracting tool calls: %s", str(e)) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Extract tool calls for streaming mode. + """ + start_idx = consume_space(0, current_text) + if current_text[start_idx:].startswith(self.bot_string): + start_idx = consume_space(start_idx + len(self.bot_string), + current_text) + + if not current_text or start_idx >= len(current_text) \ + or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + try: + # Initialize streaming state if not exists + # Try parsing as JSON to check for complete tool calls + try: + parsed_tools = json.loads(current_text) + if isinstance(parsed_tools, list): + # Update our tool array for next time + self.prev_tool_call_arr = parsed_tools + except json.JSONDecodeError: + # Not complete JSON yet, use regex for partial parsing + pass + + # This handles the case where tests manually set current_tools_sent + if len(self.current_tools_sent) > 0: + # If current_tools_sent is set to [False], + # it means the test wants us to send the name + if (len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False): + # Extract the function name using regex + + name_match = self.tool_name_reg.search(current_text) + if name_match: + function_name = name_match.group(1) + + # The test expects us to send just the name first + tool_id = f"chatcmpl-tool-{random_uuid()}" + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), # type: ignore + ) + ]) + # Update state to reflect that we've sent the name + self.current_tools_sent = [True] + self.current_tool_id = 0 + self.streaming_state["current_tool_index"] = 0 + if len(self.streaming_state["sent_tools"]) == 0: + self.streaming_state["sent_tools"].append({ + "sent_name": + True, + "sent_arguments_prefix": + False, + "sent_arguments": + "", + }) + else: + 
self.streaming_state["sent_tools"][0][ + "sent_name"] = True + self.current_tool_name_sent = True + return delta + + # Use regex to identify tool calls in the output + name_matches = list(self.tool_name_reg.finditer(current_text)) + tool_count = len(name_matches) + + # If no tools found yet, return + if tool_count == 0: + return None + + # Ensure our state arrays are large enough + while len(self.streaming_state["sent_tools"]) < tool_count: + self.streaming_state["sent_tools"].append({ + "sent_name": + False, + "sent_arguments_prefix": + False, + "sent_arguments": + "", + }) + + while len(self.streaming_state["tool_ids"]) < tool_count: + self.streaming_state["tool_ids"].append(None) + + # Determine if we need to move to a new tool + current_idx = self.streaming_state["current_tool_index"] + + # If we haven't processed any tool yet or current tool is complete, + # move to next + if current_idx == -1 or current_idx < tool_count - 1: + next_idx = current_idx + 1 + + # If tool at next_idx has not been sent yet + if (next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx] + ["sent_name"]): + # Update indexes + self.streaming_state["current_tool_index"] = next_idx + self.current_tool_id = ( + next_idx # For backward compatibility + ) + current_idx = next_idx + + # Extract the tool name + tool_name = name_matches[current_idx].group(1) + + # Generate ID and send tool name + tool_id = f"call_{current_idx}_{random_uuid()}" + self.streaming_state["tool_ids"][current_idx] = tool_id + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=tool_name).model_dump( + exclude_none=True), # type: ignore + ) + ]) + self.streaming_state["sent_tools"][current_idx][ + "sent_name"] = True + self.current_tool_name_sent = ( + True # For backward compatibility + ) + + # Keep track of streamed args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + + return delta + + # Process arguments for the current tool + if current_idx >= 0 and current_idx < tool_count: + # Support both regular and empty argument objects + # First, check for the empty arguments case: "arguments": {} + empty_args_match = self.tool_empty_arg_reg.search(current_text) + + # Check if this tool has empty arguments + if empty_args_match and empty_args_match.start() > 0: + # Find which tool this empty arguments belongs to + for i in range(tool_count): + if i == current_idx: + # If this is our current tool + # and it has empty arguments + if not self.streaming_state["sent_tools"][ + current_idx]["sent_arguments_prefix"]: + # Send empty object + self.streaming_state["sent_tools"][ + current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][ + current_idx]["sent_arguments"] = "{}" + + # Update streamed_args + # for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{}" + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}"). 
+ model_dump( + exclude_none=True), # type: ignore + ) + ]) + + # Move to next tool if available + if current_idx < tool_count - 1: + self.streaming_state[ + "current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] + + return delta + + # Extract arguments for current tool + # using regex for non-empty arguments + args_matches = list( + self.tool_non_empty_arg_reg.finditer(current_text)) + + if current_idx < len(args_matches): + args_text = args_matches[current_idx].group(1) + + # Handle transition between tools + is_last_tool = current_idx == tool_count - 1 + + # Find where the arguments for our current tool end + if not is_last_tool: + # If we have more tools after this one, + # try to find the complete argument block + next_tool_pos = current_text.find( + "},{", args_matches[current_idx].start()) + if next_tool_pos != -1: + args_end_pos = (next_tool_pos + 1 + ) # +1 to include the '}' + args_text = (current_text[args_matches[current_idx] + .start():args_end_pos]. + split('"arguments":')[1].strip()) + + # If arguments haven't been sent yet + sent_args = self.streaming_state["sent_tools"][ + current_idx]["sent_arguments"] + + # If we haven't sent the opening bracket yet + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] and args_text.startswith( + "{"): + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = "{" + + # Update streamed_args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{" + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{").model_dump( + exclude_none=True), # type: ignore + ) + ]) + return delta + + # If we need to send more arguments + if args_text.startswith(sent_args): + # Calculate what part of arguments we need to send + args_diff = args_text[len(sent_args):] + + if args_diff: + # Update our state + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = args_text + + # Update streamed_args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += args_diff + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff).model_dump( + exclude_none=True), # type: ignore + ) + ]) + return delta + + # If the tool's arguments are complete, + # check if we need to move to the next tool + if args_text.endswith("}") and args_text == sent_args: + # This tool is complete, + # move to the next one in the next iteration + if current_idx < tool_count - 1: + self.streaming_state["current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] # For compatibility + + # If we got here, we couldn't determine what to stream next + return None + + except Exception as e: + logger.exception("Error in streaming tool calls:", str(e)) + # If we encounter an error, just return the delta text as + # regular content + return DeltaMessage(content=delta_text) diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index fb29d51eae8c..b2452b95c1c6 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -83,6 +83,13 @@ def __init__(self, tokenizer: 
PreTrainedTokenizerBase):
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.current_state == "response"
 
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        # for hunyuan streaming reasoning parsing, is_reasoning_end and
+        # extract_content_ids are called with the same token during
+        # streaming; that token is not part of the content, so just
+        # return [] here.
+        return []
+
     def extract_reasoning_content(
             self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:

From c0fdba0f1401b6e85904b035405a53684d257a85 Mon Sep 17 00:00:00 2001
From: Asher Zhang
Date: Fri, 11 Jul 2025 22:03:13 +0800
Subject: [PATCH 2/5] [Doc] Add Hunyuan tool calling and reasoning docs.

Signed-off-by: Asher Zhang
---
 docs/features/reasoning_outputs.md |  1 +
 docs/features/tool_calling.md      | 10 ++++++++++
 2 files changed, 11 insertions(+)

diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md
index 7ab7efd5e765..6b84eca27530 100644
--- a/docs/features/reasoning_outputs.md
+++ b/docs/features/reasoning_outputs.md
@@ -14,6 +14,7 @@ vLLM currently supports the following reasoning models:
 | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
 | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
 | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ |
+| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` | ✅ |
 
 !!! note
     IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md
index f1e5dad35f10..9b9d6e1360e9 100644
--- a/docs/features/tool_calling.md
+++ b/docs/features/tool_calling.md
@@ -288,6 +288,16 @@ Supported models:
 
 Flags: `--tool-call-parser kimi_k2`
 
+### Hunyuan Models (`hunyuan_a13b`)
+
+Supported models:
+
+* `tencent/Hunyuan-A13B-Instruct` (the chat template is already included in the Hugging Face model files)
+
+Flags:
+* For non-reasoning: `--tool-call-parser hunyuan_a13b`
+* For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable-reasoning`
+
 ### Models with Pythonic Tool Calls (`pythonic`)
 
 A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.

From 066434f06f4f29fa0451b7e66de6d8f8ef7a34ac Mon Sep 17 00:00:00 2001
From: Asher Zhang
Date: Sat, 12 Jul 2025 17:05:22 +0800
Subject: [PATCH 3/5] [Model] Refine the Hunyuan A13B tool parser and add
 tests.

- add tests for the hunyuan a13b tool parser.
- fix mypy errors in the tool parser.
- refine the reasoning parser tests.
- refactor the tool parser stream function.

Signed-off-by: Asher Zhang --- .../test_hunyuan_a13b_tool_parser.py | 153 ++++++ .../test_hunyuan_reasoning_parser.py | 11 + .../tool_parsers/hunyuan_a13b_tool_parser.py | 476 ++++++++---------- 3 files changed, 364 insertions(+), 276 deletions(-) create mode 100644 tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py new file mode 100644 index 000000000000..bd8e06513e13 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from unittest.mock import MagicMock + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, run_tool_extraction_streaming) +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + + +def make_tool_call(name, arguments): + return ToolCall(type="function", + function=FunctionCall(name=name, + arguments=json.dumps(arguments))) + + +# TODO: add reason prefix and suffix. + + +@pytest.mark.parametrize( + "model_output,expected_tool_calls,expected_content", + [ + # No tool call + ("How can I help you today?", [], "How can I help you today?"), + # Single tool call, no content + ( + "[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]", #noqa: E501 + [ + make_tool_call("get_weather", { + "city": "San Francisco", + "metric": "celsius" + }) + ], + None), + # Multiple tool calls + ( + "[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]", #noqa: E501 + [ + make_tool_call("get_weather", { + "city": "San Francisco", + "metric": "celsius" + }), + make_tool_call( + "register_user", { + "name": "John Doe", + "age": 37, + "address": { + "city": "San Francisco", + "state": "CA" + }, + "role": None, + "passed_test": True, + "aliases": ["John", "Johnny"] + }) + ], + None), + # Content before tool call + ( + "I will call the tool now. [{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]", #noqa: E501 + [make_tool_call("get_weather", {"city": "Boston"})], + "I will call the tool now. "), + # Content after tool call (should be stripped) + ( + "[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]\nThank you!", #noqa: E501 + [make_tool_call("get_weather", {"city": "Seattle"})], + None), + ( + "[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]", + [ + make_tool_call( + "complex_tool", + {"level1": { + "level2": { + "level3": { + "value": 123 + } + } + }}) + ], + None, + ), + ]) +def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, + expected_content): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "hunyuan_a13b")(mock_tokenizer) + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=False) + + # align the random id. 
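+    # (the parser assigns a fresh random id to each call, so ids cannot
+    # be compared directly)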
+    for idx in range(len(tool_calls)):
+        tool_calls[idx].id = expected_tool_calls[idx].id
+    assert tool_calls == expected_tool_calls
+    assert content == expected_content
+
+
+# Streaming test: simulate incremental output
+@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
+    ([
+        "[{\"name\": \"get_weather\", ",
+        "\"arguments\": {\"city\": \"San Francisco\", ",
+        "\"metric\": \"celsius\"}}]", ""
+    ], [
+        make_tool_call("get_weather", {
+            "city": "San Francisco",
+            "metric": "celsius"
+        })
+    ]),
+    ([
+        "[{\"name\":", " \"get_weather\",", " \"arguments\":",
+        " {\"city\": \"Boston\"}", "}]", ""
+    ], [make_tool_call("get_weather", {"city": "Boston"})]),
+    ([
+        "<tool_calls>", "[{\"name\":", " \"get_weather\",", " \"arguments\":",
+        " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n"
+    ], [make_tool_call("get_weather", {"city": "Boston"})]),
+    pytest.param([
+        "[{\"name\": \"complex_tool\",", " \"arguments\": ",
+        " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
+        "]"
+    ], [
+        make_tool_call("complex_tool",
+                       {"level1": {
+                           "level2": {
+                               "level3": {
+                                   "value": 123
+                               }
+                           }
+                       }})
+    ],
+                 marks=pytest.mark.xfail(
+                     reason="stream parsing does not support nested json yet.")),
+])
+def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
+    mock_tokenizer = MagicMock()
+
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
+        "hunyuan_a13b")(mock_tokenizer)
+    reconstructor = run_tool_extraction_streaming(
+        tool_parser, model_deltas, assert_one_tool_per_delta=False)
+
+    # align the random id.
+    for idx in range(len(reconstructor.tool_calls)):
+        reconstructor.tool_calls[idx].id = expected_tool_calls[idx].id
+
+    assert reconstructor.tool_calls == expected_tool_calls
diff --git a/tests/reasoning/test_hunyuan_reasoning_parser.py b/tests/reasoning/test_hunyuan_reasoning_parser.py
index f70cf453f0e9..f9238267f02e 100644
--- a/tests/reasoning/test_hunyuan_reasoning_parser.py
+++ b/tests/reasoning/test_hunyuan_reasoning_parser.py
@@ -30,6 +30,12 @@
     "reasoning_content": "This is a reasoning section",
     "content": None,
 }
+
+COMPLETE_REASONING_WITH_SYMBOL = {
+    "output": f"{START_REASONING}This is a reasoning section!{START_RESPONSE}",
+    "reasoning_content": "This is a reasoning section!",
+    "content": None,
+}
 NO_REASONING = {
     "output": "This is content",
     "reasoning_content": None,
@@ -70,6 +76,11 @@
         COMPLETE_REASONING,
         id="complete_reasoning",
     ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING_WITH_SYMBOL,
+        id="complete_reasoning_with_symbol",
+    ),
     pytest.param(
         False,
         NO_REASONING,
diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
index 060baaec1f7b..2b65f2579fb4 100644
--- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py
@@ -29,7 +29,6 @@ class HunyuanA13BToolParser(ToolParser):
 
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
-        logger.info("Hunyuan A13B Tool Parser inited.")
         # Initialize state for streaming mode
         self.prev_tool_calls: list[dict] = []
         self.current_tool_id = -1
@@ -45,13 +44,14 @@ def __init__(self, tokenizer: AnyTokenizer):
 
         # Regex patterns for preprocessing
         self.answer_tool_calls_pattern = re.compile(
-            r"<tool_calls>([\s\S]*?)</tool_calls>")
+            r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
 
         self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')
 
         self.tool_empty_arg_reg = re.compile(
             r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
 
+        # TODO: nested json objects in fc arguments are not supported.
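+        # e.g. the pattern below matches {"x": 1} and {"x": {"y": 1}},
+        # but not the doubly nested {"x": {"y": {"z": 1}}}.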
         self.tool_non_empty_arg_reg = re.compile(
             r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
         )
@@ -67,31 +67,22 @@ def __init__(self, tokenizer: AnyTokenizer):
 
     def preprocess_model_output(
             self, model_output: str) -> tuple[Optional[str], Optional[str]]:
-        """
-        Preprocess the model output to extract content and potential tool calls.
-        Returns:
-            Tuple of (content, potential_tool_calls_json)
-        """
-        # Check for <answer> with <tool_calls></tool_calls> pattern
-
-        answer_tool_calls_match = self.answer_tool_calls_pattern.search(
-            model_output)
-        if answer_tool_calls_match:
-            tool_calls_content = answer_tool_calls_match.group(1).strip()
-
-            content = re.sub(r'<tool_calls>.*?</tool_calls>',
-                             '',
-                             model_output,
-                             flags=re.DOTALL)
-            content = content[:content.rfind("</answer>") + len("</answer>")] \
-                if "</answer>" in content else content
-
-            try:
-                json.loads(tool_calls_content)
-                return content, tool_calls_content
-            except json.JSONDecodeError:
-                # If can't parse directly, continue with other patterns
-                pass
+        # find the location of the tool call
+        for match in self.answer_tool_calls_pattern.finditer(model_output):
+            start, end = match.span()
+            # check whether tool_calls is inside of <think></think>
+            think_regions = [(m.start(), m.end()) for m in re.finditer(
+                r"<think>(.*?)</think>", model_output, flags=re.DOTALL)]
+            in_think = any(start > t_start and end < t_end
+                           for t_start, t_end in think_regions)
+            if not in_think:
+                content = model_output[:start]
+                tool_calls_content = match.group(1).strip()
+                try:
+                    json.loads(tool_calls_content)
+                    return content, tool_calls_content
+                except Exception:
+                    continue
         return model_output, None
 
     def extract_tool_calls(
@@ -108,10 +99,11 @@ def extract_tool_calls(
             if not potential_tool_calls:
                 # some text should be filtered out when there is no
                 # function call; that text comes from a13b's chat template.
-                cleaned_output = content.replace("助手:", "", 1)
+                if content:
+                    content = content.replace("助手:", "", 1)
                 return ExtractedToolCallInformation(tools_called=False,
                                                     tool_calls=[],
-                                                    content=cleaned_output)
+                                                    content=content)
 
             # Parse the potential tool calls as JSON
             tool_calls_data = json.loads(potential_tool_calls)
@@ -130,7 +122,6 @@ def extract_tool_calls(
         for idx, call in enumerate(tool_calls_data):
             if (not isinstance(call, dict) or "name" not in call
                     or "arguments" not in call):
-                logger.debug("Invalid tool call format at index %d", idx)
                 continue
 
             tool_call = ToolCall(
@@ -144,15 +135,17 @@ def extract_tool_calls(
             )
             tool_calls.append(tool_call)
 
-        logger.debug("exit in last position.")
+        if not content or len(content.strip()) == 0:
+            # treat whitespace-only content as None.
+            content = None
+
         return ExtractedToolCallInformation(
             tools_called=len(tool_calls) > 0,
             tool_calls=tool_calls,
-            content=content if len(content.strip()) > 0 else None,
+            content=content,
         )
 
-        except Exception as e:
-            logger.exception("Error extracting tool calls: %s", str(e))
+        except Exception:
             return ExtractedToolCallInformation(tools_called=False,
                                                 tool_calls=[],
                                                 content=model_output)
@@ -170,279 +163,210 @@ def extract_tool_calls_streaming(
         """
         Extract tool calls for streaming mode.
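+
+        Returns a DeltaMessage carrying either plain content or incremental
+        tool call deltas, and None when nothing can be emitted yet.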
""" + start_idx = consume_space(0, current_text) if current_text[start_idx:].startswith(self.bot_string): start_idx = consume_space(start_idx + len(self.bot_string), current_text) - - if not current_text or start_idx >= len(current_text) \ - or current_text[start_idx] != '[': + if not current_text or start_idx >= len( + current_text) or current_text[start_idx] != '[': return DeltaMessage(content=delta_text) - try: - # Initialize streaming state if not exists - # Try parsing as JSON to check for complete tool calls - try: - parsed_tools = json.loads(current_text) - if isinstance(parsed_tools, list): - # Update our tool array for next time - self.prev_tool_call_arr = parsed_tools - except json.JSONDecodeError: - # Not complete JSON yet, use regex for partial parsing - pass - - # This handles the case where tests manually set current_tools_sent - if len(self.current_tools_sent) > 0: - # If current_tools_sent is set to [False], - # it means the test wants us to send the name - if (len(self.current_tools_sent) == 1 - and self.current_tools_sent[0] is False): - # Extract the function name using regex - - name_match = self.tool_name_reg.search(current_text) - if name_match: - function_name = name_match.group(1) - - # The test expects us to send just the name first - tool_id = f"chatcmpl-tool-{random_uuid()}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=0, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), # type: ignore - ) - ]) - # Update state to reflect that we've sent the name - self.current_tools_sent = [True] - self.current_tool_id = 0 - self.streaming_state["current_tool_index"] = 0 - if len(self.streaming_state["sent_tools"]) == 0: - self.streaming_state["sent_tools"].append({ - "sent_name": - True, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) - else: - self.streaming_state["sent_tools"][0][ - "sent_name"] = True - self.current_tool_name_sent = True - return delta + self._try_parse_json_tools(current_text[start_idx:]) - # Use regex to identify tool calls in the output - name_matches = list(self.tool_name_reg.finditer(current_text)) - tool_count = len(name_matches) - - # If no tools found yet, return - if tool_count == 0: - return None - - # Ensure our state arrays are large enough - while len(self.streaming_state["sent_tools"]) < tool_count: - self.streaming_state["sent_tools"].append({ - "sent_name": - False, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) - - while len(self.streaming_state["tool_ids"]) < tool_count: - self.streaming_state["tool_ids"].append(None) - - # Determine if we need to move to a new tool - current_idx = self.streaming_state["current_tool_index"] - - # If we haven't processed any tool yet or current tool is complete, - # move to next - if current_idx == -1 or current_idx < tool_count - 1: - next_idx = current_idx + 1 - - # If tool at next_idx has not been sent yet - if (next_idx < tool_count - and not self.streaming_state["sent_tools"][next_idx] - ["sent_name"]): - # Update indexes - self.streaming_state["current_tool_index"] = next_idx - self.current_tool_id = ( - next_idx # For backward compatibility - ) - current_idx = next_idx + test_delta = self._handle_test_compatibility(current_text) + if test_delta: + return test_delta + + name_matches = list(self.tool_name_reg.finditer(current_text)) + tool_count = len(name_matches) + if tool_count == 0: + return None + self._ensure_state_arrays(tool_count) + current_idx = 
self.streaming_state["current_tool_index"] - # Extract the tool name - tool_name = name_matches[current_idx].group(1) + name_delta = self._handle_tool_name_streaming(current_idx, tool_count, + name_matches) + if name_delta: + return name_delta - # Generate ID and send tool name - tool_id = f"call_{current_idx}_{random_uuid()}" - self.streaming_state["tool_ids"][current_idx] = tool_id + args_delta = self._handle_tool_args_streaming(current_text, + current_idx, tool_count) + if args_delta: + return args_delta + return None + + def _try_parse_json_tools(self, current_text: str): + try: + parsed_tools = json.loads(current_text) + if isinstance(parsed_tools, list): + self.prev_tool_call_arr = parsed_tools + except json.JSONDecodeError: + pass + + def _handle_test_compatibility(self, current_text: str): + if len(self.current_tools_sent) > 0: + if (len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False): + name_match = self.tool_name_reg.search(current_text) + if name_match: + function_name = name_match.group(1) + tool_id = f"chatcmpl-tool-{random_uuid()}" delta = DeltaMessage(tool_calls=[ DeltaToolCall( - index=current_idx, + index=0, type="function", id=tool_id, function=DeltaFunctionCall( - name=tool_name).model_dump( - exclude_none=True), # type: ignore + name=function_name).model_dump( + exclude_none=True), ) ]) - self.streaming_state["sent_tools"][current_idx][ - "sent_name"] = True - self.current_tool_name_sent = ( - True # For backward compatibility + self.current_tools_sent = [True] + self.current_tool_id = 0 + self.streaming_state["current_tool_index"] = 0 + if len(self.streaming_state["sent_tools"]) == 0: + self.streaming_state["sent_tools"].append({ + "sent_name": + True, + "sent_arguments_prefix": + False, + "sent_arguments": + "", + }) + else: + self.streaming_state["sent_tools"][0][ + "sent_name"] = True + self.current_tool_name_sent = True + return delta + return None + + def _ensure_state_arrays(self, tool_count: int): + while len(self.streaming_state["sent_tools"]) < tool_count: + self.streaming_state["sent_tools"].append({ + "sent_name": False, + "sent_arguments_prefix": False, + "sent_arguments": "", + }) + while len(self.streaming_state["tool_ids"]) < tool_count: + self.streaming_state["tool_ids"].append(None) + + def _handle_tool_name_streaming(self, current_idx: int, tool_count: int, + name_matches): + if current_idx == -1 or current_idx < tool_count - 1: + next_idx = current_idx + 1 + if (next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx] + ["sent_name"]): + self.streaming_state["current_tool_index"] = next_idx + self.current_tool_id = next_idx + current_idx = next_idx + tool_name = name_matches[current_idx].group(1) + tool_id = f"call_{current_idx}_{random_uuid()}" + self.streaming_state["tool_ids"][current_idx] = tool_id + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True), ) + ]) + self.streaming_state["sent_tools"][current_idx][ + "sent_name"] = True + self.current_tool_name_sent = True + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + return delta + return None + + def _handle_tool_args_streaming(self, current_text: str, current_idx: int, + tool_count: int): + + if current_idx >= 0 and current_idx < tool_count: + empty_args_match = self.tool_empty_arg_reg.search(current_text) + if empty_args_match and empty_args_match.start() > 0: + for i in range(tool_count): 
+ if i == current_idx: + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"]: + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = "{}" + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{}" + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}").model_dump( + exclude_none=True), + ) + ]) + if current_idx < tool_count - 1: + self.streaming_state["current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] + return delta - # Keep track of streamed args for backward compatibility + args_matches = list( + self.tool_non_empty_arg_reg.finditer(current_text)) + if current_idx < len(args_matches): + args_text = args_matches[current_idx].group(1) + is_last_tool = current_idx == tool_count - 1 + if not is_last_tool: + next_tool_pos = current_text.find( + "},{", args_matches[current_idx].start()) + if next_tool_pos != -1: + args_end_pos = (next_tool_pos + 1) + args_text = ( + current_text[args_matches[current_idx].start( + ):args_end_pos].split('"arguments":')[1].strip()) + sent_args = self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] and args_text.startswith("{"): + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = "{" while len(self.streamed_args) <= current_idx: self.streamed_args.append("") - + self.streamed_args[current_idx] += "{" + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{").model_dump(exclude_none=True), + ) + ]) return delta - # Process arguments for the current tool - if current_idx >= 0 and current_idx < tool_count: - # Support both regular and empty argument objects - # First, check for the empty arguments case: "arguments": {} - empty_args_match = self.tool_empty_arg_reg.search(current_text) - - # Check if this tool has empty arguments - if empty_args_match and empty_args_match.start() > 0: - # Find which tool this empty arguments belongs to - for i in range(tool_count): - if i == current_idx: - # If this is our current tool - # and it has empty arguments - if not self.streaming_state["sent_tools"][ - current_idx]["sent_arguments_prefix"]: - # Send empty object - self.streaming_state["sent_tools"][ - current_idx][ - "sent_arguments_prefix"] = True - self.streaming_state["sent_tools"][ - current_idx]["sent_arguments"] = "{}" - - # Update streamed_args - # for backward compatibility - while len(self.streamed_args) <= current_idx: - self.streamed_args.append("") - self.streamed_args[current_idx] += "{}" - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{}"). 
- model_dump( - exclude_none=True), # type: ignore - ) - ]) - - # Move to next tool if available - if current_idx < tool_count - 1: - self.streaming_state[ - "current_tool_index"] += 1 - self.current_tool_id = self.streaming_state[ - "current_tool_index"] - - return delta - - # Extract arguments for current tool - # using regex for non-empty arguments - args_matches = list( - self.tool_non_empty_arg_reg.finditer(current_text)) - - if current_idx < len(args_matches): - args_text = args_matches[current_idx].group(1) - - # Handle transition between tools - is_last_tool = current_idx == tool_count - 1 - - # Find where the arguments for our current tool end - if not is_last_tool: - # If we have more tools after this one, - # try to find the complete argument block - next_tool_pos = current_text.find( - "},{", args_matches[current_idx].start()) - if next_tool_pos != -1: - args_end_pos = (next_tool_pos + 1 - ) # +1 to include the '}' - args_text = (current_text[args_matches[current_idx] - .start():args_end_pos]. - split('"arguments":')[1].strip()) - - # If arguments haven't been sent yet - sent_args = self.streaming_state["sent_tools"][ - current_idx]["sent_arguments"] - - # If we haven't sent the opening bracket yet - if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] and args_text.startswith( - "{"): + if args_text.startswith(sent_args): + args_diff = args_text[len(sent_args):] + if args_diff: self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True - self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{" - - # Update streamed_args for backward compatibility + "sent_arguments"] = args_text while len(self.streamed_args) <= current_idx: self.streamed_args.append("") - self.streamed_args[current_idx] += "{" - + self.streamed_args[current_idx] += args_diff delta = DeltaMessage(tool_calls=[ DeltaToolCall( index=current_idx, function=DeltaFunctionCall( - arguments="{").model_dump( - exclude_none=True), # type: ignore + arguments=args_diff).model_dump( + exclude_none=True), ) ]) return delta - # If we need to send more arguments - if args_text.startswith(sent_args): - # Calculate what part of arguments we need to send - args_diff = args_text[len(sent_args):] - - if args_diff: - # Update our state - self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = args_text - - # Update streamed_args for backward compatibility - while len(self.streamed_args) <= current_idx: - self.streamed_args.append("") - self.streamed_args[current_idx] += args_diff - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments=args_diff).model_dump( - exclude_none=True), # type: ignore - ) - ]) - return delta - - # If the tool's arguments are complete, - # check if we need to move to the next tool - if args_text.endswith("}") and args_text == sent_args: - # This tool is complete, - # move to the next one in the next iteration - if current_idx < tool_count - 1: - self.streaming_state["current_tool_index"] += 1 - self.current_tool_id = self.streaming_state[ - "current_tool_index"] # For compatibility - - # If we got here, we couldn't determine what to stream next - return None - - except Exception as e: - logger.exception("Error in streaming tool calls:", str(e)) - # If we encounter an error, just return the delta text as - # regular content - return DeltaMessage(content=delta_text) + if args_text.endswith("}") and args_text == sent_args: + if current_idx < tool_count - 1: + 
self.streaming_state["current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] + return None From 110f01b257a9b45050df9fff5345614c2a2ded69 Mon Sep 17 00:00:00 2001 From: Asher Zhang Date: Mon, 14 Jul 2025 20:30:02 +0800 Subject: [PATCH 4/5] [Model] Hunyuan A13B : add function call template Signed-off-by: Asher Zhang --- .../tool_chat_template_hunyuan_a13b.jinja | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/tool_chat_template_hunyuan_a13b.jinja diff --git a/examples/tool_chat_template_hunyuan_a13b.jinja b/examples/tool_chat_template_hunyuan_a13b.jinja new file mode 100644 index 000000000000..a0808e44858a --- /dev/null +++ b/examples/tool_chat_template_hunyuan_a13b.jinja @@ -0,0 +1,113 @@ +{% set loop_messages = messages %} +{% if tools %} + {% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %} + {% set weekday_cn = weekday_map[strftime_now('%A')] %} + {% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %} + {% set datetime_str = datetime_str + ' ' + weekday_cn %} + {% for message in loop_messages %} + {% if 'content' in message %} + {% set content = message['content'] %} + {% else %} + {% set content = '' %} + {% endif %} + {% if loop.index0 == 0 %} + {% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。 +如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。 +如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。 +如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。 +你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。 +' %} + {% set content_tmp = content_tmp + ' +' + tools | tojson + ' +' %} + {% if message['role'] == 'system' %} + {% set content_tmp = content_tmp + ' +额外要求: +' + content + ' + +如果你决定返回函数调用,请将其格式化为[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...],不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。 +否则,请参考开头说的三种情况,以助手:开头进行回复。 + +如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %} + {% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %} + {% elif message['role'] == 'user' %} + {% set content_tmp = content_tmp + ' +如果你决定返回函数调用,请将其格式化为[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...],不得包含其他文本。 +否则,请参考开头说的三种情况,以助手:开头进行回复。 + +当前时间:' + datetime_str %} + {% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%} + {% set content = content_tmp + '用户:' + content + '<|extra_0|>' %} + {% endif %} + {% else %} + {% if message['role'] == 'user' %} + {% set content = '用户:' + content + '<|extra_0|>' %} + {% elif message['role'] == 'assistant' %} + {% if 'tool_calls' in message %} + {% set tool_calls = message['tool_calls'] %} + {% set ns = namespace(tool_calls="[") %} + {% for tool_call in tool_calls %} + {% set function = tool_call['function'] %} + {% set name = function['name'] %} + {% set ns.tool_calls = ns.tool_calls + '{"name": "' + name + '", '%} + {% set arguments = function['arguments'] %} + {% if arguments is not string %} + {% set arguments = arguments | tojson %} + {% endif %} + {% set ns.tool_calls = ns.tool_calls + '"arguments": ' + arguments + '}' %} + {% if not loop.last %} + {% set ns.tool_calls = ns.tool_calls + ', '%} + {% endif %} + {% endfor %} + {% set ns.tool_calls = ns.tool_calls + ']' %} + {% set content = content + '' + ns.tool_calls + '' %} + {% else %} + {% set content = '助手:' + content %} 
+                {% endif %}
+                {% set content = content + '<|eos|>' %}
+            {% elif message['role'] == 'tool' %}
+                {% if content is not string %}
+                    {% set content = content | tojson %}
+                {% endif %}
+                {% set content = '<tool_response>' + content + '</tool_response>' %}
+                {% set content = content + '<|extra_0|>' %}
+            {% endif %}
+        {% endif %}
+        {{- content -}}
+    {% endfor %}
+{% else %}
+    {% set context = {'has_head': true} %}
+    {% for message in loop_messages %}
+        {% if 'content' in message %}
+            {% set content = message['content'] %}
+        {% else %}
+            {% set content = '' %}
+        {% endif %}
+        {% if loop.index0 == 0 %}
+            {% if content == '' %}
+                {% set _ = context.update({'has_head': false}) %}
+            {% elif message['role'] == 'system' %}
+                {% set content = '<|startoftext|>' + content + '<|extra_4|>' %}
+            {% endif %}
+        {% endif %}
+        {% if message['role'] == 'user' %}
+            {% if loop.index0 == 1 and not context.has_head %}
+                {% set content = '<|startoftext|>' + content %}
+            {% endif %}
+            {% if loop.index0 == 1 and context.has_head %}
+                {% set content = content + '<|extra_0|>' %}
+            {% else %}
+                {% set content = '<|startoftext|>' + content + '<|extra_0|>' %}
+            {% endif %}
+        {% elif message['role'] == 'assistant' %}
+            {% set content = content + '<|eos|>' %}
+        {% elif message['role'] == 'tool' %}
+            {% set content = content + '<|extra_0|>' %}
+        {% endif %}
+        {{- content -}}
+    {% endfor %}
+{% endif %}
+{%- if enable_thinking is defined and enable_thinking is false %}
+    {{- '<think>\n\n</think>\n' }}
+{%- endif %}
+

From af5c48ac6b117369299f30599ed9d72e2c41155e Mon Sep 17 00:00:00 2001
From: Asher Zhang
Date: Mon, 14 Jul 2025 20:32:03 +0800
Subject: [PATCH 5/5] [Model] Add MoE configs for Hunyuan A13B and benchmark
 support

- tune the fused MoE configs.
- benchmark: add Hunyuan to the MoE benchmark.

Signed-off-by: Asher Zhang
---
 benchmarks/kernels/benchmark_moe.py           |   5 +
 ...device_name=NVIDIA_H20,dtype=fp8_w8a8.json | 146 ++++++++++++++++++
 ...device_name=NVIDIA_H20,dtype=fp8_w8a8.json | 146 ++++++++++++++++++
 .../E=64,N=3072,device_name=NVIDIA_H20.json   | 146 ++++++++++++++++++
 ...device_name=NVIDIA_H20,dtype=fp8_w8a8.json | 146 ++++++++++++++++++
 .../E=64,N=384,device_name=NVIDIA_H20.json    | 146 ++++++++++++++++++
 ...device_name=NVIDIA_H20,dtype=fp8_w8a8.json | 146 ++++++++++++++++++
 .../E=64,N=768,device_name=NVIDIA_H20.json    | 146 ++++++++++++++++++
 8 files changed, 1027 insertions(+)
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json

diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 51c9f68e43af..132c325ce591 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -586,6 +586,11 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif 
config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): + E = config.num_experts + topk = config.moe_topk[0] + intermediate_size = config.moe_intermediate_size[0] + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Support for llama4 config = config.get_text_config() diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..298a36175e60 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..0e210cb0f38d --- /dev/null +++ 
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json
new file mode 100644
index 000000000000..e4fa1e2e6e9b
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
new file mode 100644
index 000000000000..082456d319d3
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json
new file mode 100644
index 000000000000..c3b2e7fa91eb
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
new file mode 100644
index 000000000000..bba1d21aa2b6
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json
new file mode 100644
index 000000000000..de1c413b6e1a
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}