[Feature] Add command tool parser for Command-A model #20633

Closed · wants to merge 4 commits
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .abstract_tool_parser import ToolParser, ToolParserManager
+from .command_tool_parser import CommandToolParser
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
@@ -23,5 +24,5 @@
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
"KimiK2ToolParser"
"CommandToolParser", "KimiK2ToolParser"
]
151 changes: 151 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/command_tool_parser.py
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence
from typing import Union

import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("command")
class CommandToolParser(ToolParser):

def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Streaming state
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
self.current_tool_id: int = -1
Contributor comment (critical):

current_tool_id should be initialized to 0 instead of -1. This is part of a fix for a bug that causes non-contiguous tool call indices in streaming mode.

Suggested change:
-self.current_tool_id: int = -1
+self.current_tool_id: int = 0

self.current_tool_name_sent: bool = False
Comment on lines +32 to +35
Contributor comment (medium):

The instance variables prev_tool_call_arr, streamed_args_for_tool, and current_tool_name_sent are initialized but appear to be unused within the class. To improve code clarity and maintainability, they should be removed.
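
For illustration, a sketch of the constructor with those attributes dropped, assuming nothing outside this class reads them (an editor's hypothetical simplification, not part of this PR):

def __init__(self, tokenizer: PreTrainedTokenizerBase):
    super().__init__(tokenizer)
    # Only the streaming index is kept as parser state
    # (starting at 0, per the earlier suggestion).
    self.current_tool_id: int = 0
    # ... the delimiter and token-id setup below stays unchanged.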


# Action delimiters
self.tool_call_start_token = "<|START_ACTION|>"
self.tool_call_end_token = "<|END_ACTION|>"
self.tool_call_regex = re.compile(
r"<\|START_ACTION\|>(.*?)<\|END_ACTION\|>", re.DOTALL)

# Precompute token ids
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
raise RuntimeError(
"CommandToolParser cannot find start/end tokens in vocab")

def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
# Synchronous parsing: look for full action block
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
match = self.tool_call_regex.search(model_output)
if not match:
raise ValueError("No action block found")
payload = match.group(1)
raw_calls = json.loads(payload)
tool_calls = []
for entry in raw_calls:
name = entry.get("tool_name")
params = entry.get("parameters", {})
tool_calls.append(
ToolCall(type="function",
function=FunctionCall(name=name,
arguments=json.dumps(
params,
ensure_ascii=False))))
# content before action
prefix = model_output.split(self.tool_call_start_token, 1)[0]
return ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=prefix or None)
except Exception:
Contributor comment (high):

Catching a broad Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to handle, such as json.JSONDecodeError and ValueError.

Suggested change:
-except Exception:
+except (json.JSONDecodeError, ValueError):

logger.exception("Error extracting sync tool calls")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)

def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:

prev_start = previous_token_ids.count(self.tool_call_start_token_id)
cur_start = current_token_ids.count(self.tool_call_start_token_id)
cur_end = current_token_ids.count(self.tool_call_end_token_id)
Comment on lines +98 to +100
Contributor comment (medium):

The count() method is called on previous_token_ids and current_token_ids in each invocation of this streaming method. Since these lists can grow large, this is inefficient as it re-scans the entire list every time. Consider maintaining the counts as part of the parser's state and updating them incrementally with delta_token_ids to improve performance.
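
A minimal sketch of the incremental-counting idea, assuming two new counter attributes (hypothetical names _start_count and _end_count, initialized to 0 in __init__; they are not part of this PR):

# Scan only the newly generated tokens instead of the full history;
# current_token_ids is previous_token_ids plus delta_token_ids, so the
# accumulated counters match the full-list counts.
delta_starts = delta_token_ids.count(self.tool_call_start_token_id)
delta_ends = delta_token_ids.count(self.tool_call_end_token_id)

prev_start = self._start_count  # count before this delta arrived
self._start_count += delta_starts
self._end_count += delta_ends

cur_start = self._start_count
cur_end = self._end_count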


# Case 1: Block not started → Text as is
if cur_start == 0:
return DeltaMessage(content=delta_text)

# Case 2: Starting a new block
if cur_start > prev_start:
self.current_tool_id += 1
Contributor comment (critical):

Incrementing self.current_tool_id here when a new action block starts (<|START_ACTION|>) is incorrect and causes non-contiguous tool call indices. The index should only be incremented after a complete tool call has been parsed. This line should be removed.
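
For illustration, a sketch of this branch with the increment removed, so the index only advances in the block-end branch below after a complete tool call has been parsed (assumes current_tool_id starts at 0 as suggested earlier; not part of this PR):

# Case 2 (sketch): a new action block just opened; emit nothing and leave
# current_tool_id untouched until the block is fully parsed.
if cur_start > prev_start:
    return None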

return None

# Case 3: Inside block, not closed → ignored
if cur_start > cur_end:
return None

# Case 4: Block End Point
if cur_start == cur_end and self.tool_call_end_token in delta_text:
full = current_text + delta_text

payload = full.split(self.tool_call_start_token, 1)[1] \
.split(self.tool_call_end_token, 1)[0].strip()
try:
calls = partial_json_parser.loads(payload or "[]", Allow.ALL)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("Waiting for complete JSON")
return None
except json.JSONDecodeError:
logger.debug("Malformed JSON payload: %s", payload)
return None

calls_list = calls if isinstance(calls, list) else [calls]
deltas = []
for entry in calls_list:
name = entry.get("tool_name")
params = entry.get("parameters", {})
args = json.dumps(params, ensure_ascii=False)
deltas.append(
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=name,
arguments=args,
).model_dump(exclude_none=True),
))

self.current_tool_id += 1

return DeltaMessage(tool_calls=deltas)

return DeltaMessage(content=delta_text)
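
For context, a hedged sketch of how a parser registered under the name "command" is typically resolved and exercised, based on the ToolParserManager interface imported at the top of this file; the tokenizer, model_output, and request objects are placeholders, and the real serving-layer call sites are outside this diff:

# Illustration only: look up the registered parser class by name and run
# the non-streaming extraction path on a finished completion.
parser_cls = ToolParserManager.get_tool_parser("command")
parser = parser_cls(tokenizer)
info = parser.extract_tool_calls(model_output, request=request)
if info.tools_called:
    for call in info.tool_calls:
        print(call.function.name, call.function.arguments)
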
2 changes: 1 addition & 1 deletion vllm/reasoning/hunyuan_a13b_reasoning_parser.py
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import re
from collections.abc import Sequence
from typing import Optional, Union

+import regex as re
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,