|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import json |
| 5 | +from collections.abc import Sequence |
| 6 | +from typing import Union |
| 7 | + |
| 8 | +import partial_json_parser |
| 9 | +import regex as re |
| 10 | +from partial_json_parser.core.options import Allow |
| 11 | +from transformers import PreTrainedTokenizerBase |
| 12 | + |
| 13 | +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, |
| 14 | + DeltaFunctionCall, DeltaMessage, |
| 15 | + DeltaToolCall, |
| 16 | + ExtractedToolCallInformation, |
| 17 | + FunctionCall, ToolCall) |
| 18 | +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| 19 | + ToolParser, ToolParserManager) |
| 20 | +from vllm.logger import init_logger |
| 21 | +from vllm.utils import random_uuid |
| 22 | + |
| 23 | +logger = init_logger(__name__) |
| 24 | + |
| 25 | + |
@ToolParserManager.register_module("command_json")
class CommandJsonToolParser(ToolParser):
    """Parse tool calls wrapped in ``<|START_ACTION|>``/``<|END_ACTION|>``.

    The text between the two delimiters is decoded as JSON. It may be a list
    of call objects or a single call object; each call object carries a
    ``tool_name`` and, optionally, a ``parameters`` mapping.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Set up streaming state and resolve the delimiter token ids.

        Args:
            tokenizer: Tokenizer handed to the base class.

        Raises:
            RuntimeError: If either delimiter is missing from ``self.vocab``.
        """
        super().__init__(tokenizer)
        # Streaming state
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []
        # Index of the most recently emitted tool call (-1: none emitted yet).
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False

        # Action delimiters
        self.tool_call_start_token = "<|START_ACTION|>"
        self.tool_call_end_token = "<|END_ACTION|>"
        self.tool_call_regex = re.compile(
            r"<\|START_ACTION\|>(.*?)<\|END_ACTION\|>", re.DOTALL)

        # Precompute token ids
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "CommandJsonToolParser cannot find start/end tokens in vocab")

    @staticmethod
    def _normalize_calls(parsed) -> list[tuple[str, object]]:
        """Return ``(tool_name, parameters)`` pairs from a decoded payload.

        A single object is treated as a one-element list. Entries that are
        not objects, or that have no non-empty string ``tool_name``, are
        skipped so one bad entry does not discard the valid ones. A missing
        or ``null`` ``parameters`` value becomes an empty dict.
        """
        entries = parsed if isinstance(parsed, list) else [parsed]
        calls: list[tuple[str, object]] = []
        for entry in entries:
            name = entry.get("tool_name") if isinstance(entry, dict) else None
            if not isinstance(name, str) or not name:
                logger.warning("Skipping tool call entry without a tool "
                               "name: %r", entry)
                continue
            params = entry.get("parameters")
            calls.append((name, {} if params is None else params))
        return calls

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streamed) model output.

        Every closed action block is parsed. Text before the first start
        delimiter is returned as ``content`` (``None`` when empty).

        Args:
            model_output: Full generated text.
            request: The originating request (not used here).

        Returns:
            The extracted calls, or ``tools_called=False`` with the whole
            output as content when no valid call could be extracted.
        """
        no_tool_calls = ExtractedToolCallInformation(tools_called=False,
                                                     tool_calls=[],
                                                     content=model_output)
        if self.tool_call_start_token not in model_output:
            return no_tool_calls

        try:
            tool_calls = []
            for payload in self.tool_call_regex.findall(model_output):
                try:
                    # An empty block is treated as "no calls", matching the
                    # streaming path.
                    parsed = json.loads(payload.strip() or "[]")
                except json.JSONDecodeError:
                    logger.warning(
                        "Skipping action block with invalid JSON: %s",
                        payload)
                    continue
                for name, params in self._normalize_calls(parsed):
                    arguments = json.dumps(params, ensure_ascii=False)
                    tool_calls.append(
                        ToolCall(type="function",
                                 function=FunctionCall(name=name,
                                                       arguments=arguments)))

            if not tool_calls:
                logger.warning("No valid tool call found in model output")
                return no_tool_calls

            # content before action
            prefix = model_output.split(self.tool_call_start_token, 1)[0]
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=prefix or None)
        except Exception:
            # Best-effort boundary: fall back to plain content rather than
            # failing the whole response.
            logger.exception("Error extracting sync tool calls")
            return no_tool_calls

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the streaming delta for the newest chunk of output.

        Text outside action blocks is passed through as content. Nothing is
        emitted while a block is open; when the block closes, every call in
        it is emitted at once, each under its own increasing ``index``.

        Returns:
            A ``DeltaMessage`` with content or tool calls, or ``None`` when
            there is nothing to send for this chunk.
        """
        prev_start = previous_token_ids.count(self.tool_call_start_token_id)
        cur_start = current_token_ids.count(self.tool_call_start_token_id)
        cur_end = current_token_ids.count(self.tool_call_end_token_id)

        # No block has started yet: plain text.
        if cur_start == 0:
            return DeltaMessage(content=delta_text)

        block_closed_now = (cur_start == cur_end
                            and self.tool_call_end_token in delta_text)
        if not block_closed_now:
            # A block was just opened or is still open: hold output back.
            if cur_start > prev_start or cur_start > cur_end:
                return None
            # Text following an already closed block.
            return DeltaMessage(content=delta_text)

        # The block closed in this chunk. Take the payload of the LAST
        # block so an earlier block is not emitted a second time.
        # NOTE(review): current_text is assumed to already include
        # delta_text, so the delta is not appended again - confirm against
        # the caller.
        _, start_found, after_start = current_text.rpartition(
            self.tool_call_start_token)
        payload, end_found, _ = after_start.partition(
            self.tool_call_end_token)
        if not start_found or not end_found:
            logger.debug("Action delimiters not present in streamed text")
            return None
        payload = payload.strip()

        try:
            parsed = partial_json_parser.loads(payload or "[]", Allow.ALL)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug("Waiting for complete JSON")
            return None
        except json.JSONDecodeError:
            logger.debug("Malformed JSON payload: %s", payload)
            return None

        deltas = []
        for name, params in self._normalize_calls(parsed):
            args = json.dumps(params, ensure_ascii=False)
            # One index per call: deltas sharing an index would be merged
            # into a single call by clients that key on the index.
            self.current_tool_id += 1
            # Keep the streaming-state lists in step with what was emitted.
            self.prev_tool_call_arr.append({
                "name": name,
                "arguments": params
            })
            self.streamed_args_for_tool.append(args)
            deltas.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    type="function",
                    id=f"chatcmpl-tool-{random_uuid()}",
                    function=DeltaFunctionCall(
                        name=name,
                        arguments=args,
                    ).model_dump(exclude_none=True),
                ))

        if not deltas:
            return None
        return DeltaMessage(tool_calls=deltas)
0 commit comments