Skip to content

Commit 8dc9555

Browse files
Ubuntugjgjos
authored andcommitted
[Feature] Add Command Tool Parser
Signed-off-by: <> Signed-off-by: Ubuntu <[email protected]> Signed-off-by: Doil Kim <[email protected]>
1 parent b942c09 commit 8dc9555

File tree

2 files changed

+154
-1
lines changed

2 files changed

+154
-1
lines changed

vllm/entrypoints/openai/tool_parsers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from .abstract_tool_parser import ToolParser, ToolParserManager
5+
from .command_tool_parser import CommandJsonToolParser
56
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
67
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
78
from .granite_tool_parser import GraniteToolParser
@@ -21,5 +22,6 @@
2122
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
2223
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
2324
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
24-
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
25+
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
26+
"CommandJsonToolParser"
2527
]
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
from collections.abc import Sequence
6+
from typing import Union
7+
8+
import partial_json_parser
9+
import regex as re
10+
from partial_json_parser.core.options import Allow
11+
from transformers import PreTrainedTokenizerBase
12+
13+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
14+
DeltaFunctionCall, DeltaMessage,
15+
DeltaToolCall,
16+
ExtractedToolCallInformation,
17+
FunctionCall, ToolCall)
18+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
19+
ToolParser, ToolParserManager)
20+
from vllm.logger import init_logger
21+
from vllm.utils import random_uuid
22+
23+
logger = init_logger(__name__)
24+
25+
26+
@ToolParserManager.register_module("command_json")
27+
class CommandJsonToolParser(ToolParser):
28+
29+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
30+
super().__init__(tokenizer)
31+
# Streaming state
32+
self.prev_tool_call_arr: list[dict] = []
33+
self.streamed_args_for_tool: list[str] = []
34+
self.current_tool_id: int = -1
35+
self.current_tool_name_sent: bool = False
36+
37+
# Action delimiters
38+
self.tool_call_start_token = "<|START_ACTION|>"
39+
self.tool_call_end_token = "<|END_ACTION|>"
40+
self.tool_call_regex = re.compile(
41+
r"<\|START_ACTION\|>(.*?)<\|END_ACTION\|>", re.DOTALL)
42+
43+
# Precompute token ids
44+
self.tool_call_start_token_id = self.vocab.get(
45+
self.tool_call_start_token)
46+
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
47+
if (self.tool_call_start_token_id is None
48+
or self.tool_call_end_token_id is None):
49+
raise RuntimeError(
50+
"CommandJsonToolParser cannot find start/end tokens in vocab")
51+
52+
def extract_tool_calls(
53+
self, model_output: str,
54+
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
55+
# Synchronous parsing: look for full action block
56+
if self.tool_call_start_token not in model_output:
57+
return ExtractedToolCallInformation(tools_called=False,
58+
tool_calls=[],
59+
content=model_output)
60+
try:
61+
match = self.tool_call_regex.search(model_output)
62+
if not match:
63+
raise ValueError("No action block found")
64+
payload = match.group(1)
65+
raw_calls = json.loads(payload)
66+
tool_calls = []
67+
for entry in raw_calls:
68+
name = entry.get("tool_name")
69+
params = entry.get("parameters", {})
70+
tool_calls.append(
71+
ToolCall(type="function",
72+
function=FunctionCall(name=name,
73+
arguments=json.dumps(
74+
params,
75+
ensure_ascii=False))))
76+
# content before action
77+
prefix = model_output.split(self.tool_call_start_token, 1)[0]
78+
return ExtractedToolCallInformation(tools_called=True,
79+
tool_calls=tool_calls,
80+
content=prefix or None)
81+
except Exception:
82+
logger.exception("Error extracting sync tool calls")
83+
return ExtractedToolCallInformation(tools_called=False,
84+
tool_calls=[],
85+
content=model_output)
86+
87+
def extract_tool_calls_streaming(
88+
self,
89+
previous_text: str,
90+
current_text: str,
91+
delta_text: str,
92+
previous_token_ids: Sequence[int],
93+
current_token_ids: Sequence[int],
94+
delta_token_ids: Sequence[int],
95+
request: ChatCompletionRequest,
96+
) -> Union[DeltaMessage, None]:
97+
98+
prev_start = previous_token_ids.count(self.tool_call_start_token_id)
99+
cur_start = current_token_ids.count(self.tool_call_start_token_id)
100+
cur_end = current_token_ids.count(self.tool_call_end_token_id)
101+
102+
# Case 1: Block not started → Text as is
103+
if cur_start == 0:
104+
return DeltaMessage(content=delta_text)
105+
106+
# Case 2: Starting a new block
107+
if cur_start > prev_start:
108+
self.current_tool_id += 1
109+
return None
110+
111+
# Case 3: Inside block, not closed → ignored
112+
if cur_start > cur_end:
113+
return None
114+
115+
# Case 4: Block End Point
116+
if cur_start == cur_end and self.tool_call_end_token in delta_text:
117+
full = current_text + delta_text
118+
119+
payload = full.split(self.tool_call_start_token, 1)[1] \
120+
.split(self.tool_call_end_token, 1)[0].strip()
121+
try:
122+
calls = partial_json_parser.loads(payload or "[]", Allow.ALL)
123+
except partial_json_parser.core.exceptions.MalformedJSON:
124+
logger.debug("Waiting for complete JSON")
125+
return None
126+
except json.JSONDecodeError:
127+
logger.debug("Malformed JSON payload: %s", payload)
128+
return None
129+
130+
calls_list = calls if isinstance(calls, list) else [calls]
131+
deltas = []
132+
for entry in calls_list:
133+
name = entry.get("tool_name")
134+
params = entry.get("parameters", {})
135+
args = json.dumps(params, ensure_ascii=False)
136+
deltas.append(
137+
DeltaToolCall(
138+
index=self.current_tool_id,
139+
type="function",
140+
id=f"chatcmpl-tool-{random_uuid()}",
141+
function=DeltaFunctionCall(
142+
name=name,
143+
arguments=args,
144+
).model_dump(exclude_none=True),
145+
))
146+
147+
self.current_tool_id += 1
148+
149+
return DeltaMessage(tool_calls=deltas)
150+
151+
return DeltaMessage(content=delta_text)

0 commit comments

Comments
 (0)