diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py new file mode 100644 index 00000000000..4b5a19be2eb --- /dev/null +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -0,0 +1,484 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import Generator +from typing import Optional + +import partial_json_parser +import pytest +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import MistralToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +MODEL = "jeffcookio/Mistral-Small-3.2-24B-Instruct-2506-awq-sym" + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def mistral_tool_parser(mistral_tokenizer): + return MistralToolParser(mistral_tokenizer) + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) == 9 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function, ( + f'got ${actual_tool_call.function}') + + +def stream_delta_message_generator( + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: AnyTokenizer, + model_output: str) -> Generator[DeltaMessage, None, None]: + all_token_ids = mistral_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=mistral_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=None, # type: ignore[arg-type] + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = previous_tokens + new_tokens if previous_tokens\ + else new_tokens + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(mistral_tool_parser): + model_output = "This is a test" + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", "single_tool_weather", "argument_before_name", + "argument_before_name_and_name_in_argument" + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + '''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_age", + arguments=json.dumps({ + "name": + "John Doe", + }))) + ], + None), + ], +) +def test_extract_tool_calls(mistral_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + "v11_single_tool", + "v11_multiple_tools", + "v11_nested_json", + "v11_special_chars", + "v11_empty_args", + "v11_complex_nested", + "v11_with_comma_separator", + "v11_with_whitespace_separator", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ('''This is a test''', [], '''This is a test'''), + ( + '''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": "3", + "b": "4" + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ''), + ( + '''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_age", + arguments=json.dumps({ + "name": + "John Doe", + }))) + ], + ''), + ( + '''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))), + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ''), + # V11 format tests + ( + '''[TOOL_CALLS] add{"a": 3, "b": 4}''', + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))) + ], + ""), + ( + '''[TOOL_CALLS] add{"a": 3, "b": 4}, get_weather{"city": "Paris", "unit": "celsius"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))), + ToolCall(function=FunctionCall(name="get_weather", + arguments=json.dumps({ + "city": "Paris", + "unit": "celsius" + }))) + ], + ""), + ( + '''[TOOL_CALLS] process_data{"input": {"nested": {"value": 42, "array": [1, 2, 3]}, "flag": true}}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="process_data", + arguments=json.dumps({ + "input": { + "nested": { + "value": 42, + "array": [1, 2, 3] + }, + "flag": True + } + }))) + ], + ""), + ( + '''[TOOL_CALLS] send_message{"text": "Hello, it's a nice day!", "recipient": "user@example.com"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="send_message", + arguments=json.dumps({ + "text": "Hello, it's a nice day!", + "recipient": "user@example.com" + }))) + ], + ""), + ( + '''[TOOL_CALLS] empty_function{}''', + [ + ToolCall(function=FunctionCall(name="empty_function", + arguments=json.dumps({}))) + ], + ""), + ( + '''[TOOL_CALLS] complex_tool{"data": {"items": [{"id": 1, "props": {"key": "value"}}, {"id": 2, "props": {"key": "other"}}], "meta": {"count": 2}}}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="complex_tool", + arguments=json.dumps({ + "data": { + "items": [ + {"id": 1, "props": {"key": "value"}}, + {"id": 2, "props": {"key": "other"}} + ], + "meta": {"count": 2} + } + }))) + ], + ""), + ( + '''[TOOL_CALLS] first_tool{"x": 1}, second_tool{"y": 2}''', + [ + ToolCall(function=FunctionCall(name="first_tool", + arguments=json.dumps({"x": 1}))), + ToolCall(function=FunctionCall(name="second_tool", + arguments=json.dumps({"y": 2}))) + ], + ""), + ( + '''[TOOL_CALLS] tool_a{"param": "A"} tool_b{"param": "B"}''', + [ + ToolCall(function=FunctionCall(name="tool_a", + arguments=json.dumps({"param": "A"}))), + ToolCall(function=FunctionCall(name="tool_b", + arguments=json.dumps({"param": "B"}))) + ], + ""), + ], +) +def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, + model_output, expected_tool_calls, + expected_content): + other_content: str = '' + function_names: list[str] = [] + function_args_strs: list[str] = [] + tool_call_idx: int = -1 + tool_call_ids: list[Optional[str]] = [] + + for delta_message in stream_delta_message_generator( + mistral_tool_parser, mistral_tokenizer, model_output): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[ + tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall(id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR))) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) + + +@pytest.mark.parametrize( + ids=[ + "v11_single_tool", + "v11_multiple_tools_comma", + "v11_nested_with_quotes", + "v11_escaped_chars", + "v11_mixed_content", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + '''[TOOL_CALLS] calculate_sum{"numbers": [1, 2, 3, 4, 5]}''', + [ + ToolCall(function=FunctionCall(name="calculate_sum", + arguments=json.dumps({ + "numbers": [1, 2, 3, 4, 5] + }))) + ], + None), + ( + '''[TOOL_CALLS] get_user{"id": 123}, update_profile{"name": "John", "age": 30}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_user", + arguments=json.dumps({"id": 123}))), + ToolCall(function=FunctionCall(name="update_profile", + arguments=json.dumps({ + "name": "John", + "age": 30 + }))) + ], + None), + ( + '''[TOOL_CALLS] parse_json{"content": "{\\"key\\": \\"value\\", \\"nested\\": {\\"item\\": 1}}"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="parse_json", + arguments=json.dumps({ + "content": "{\"key\": \"value\", \"nested\": {\"item\": 1}}" + }))) + ], + None), + ( + '''[TOOL_CALLS] format_text{"template": "Hello {name}\\nWelcome!", "vars": {"name": "User"}}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="format_text", + arguments=json.dumps({ + "template": "Hello {name}\nWelcome!", + "vars": {"name": "User"} + }))) + ], + None), + ( + '''Some content before [TOOL_CALLS] analyze_data{"dataset": "sales_2024", "metrics": ["revenue", "growth"]}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="analyze_data", + arguments=json.dumps({ + "dataset": "sales_2024", + "metrics": ["revenue", "growth"] + }))) + ], + "Some content before "), + ], +) +def test_extract_tool_calls_v11_format(mistral_tool_parser, model_output, + expected_tool_calls, expected_content): + """Test extraction of tool calls in v11 format (non-streaming)""" + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f12290..ff56868f6da 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -3,13 +3,12 @@ import json from collections.abc import Sequence +from enum import Enum from random import choices from string import ascii_letters, digits -from typing import Union +from typing import Literal, Union -import partial_json_parser import regex as re -from partial_json_parser.core.options import Allow from pydantic import Field from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -19,8 +18,6 @@ FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -29,6 +26,15 @@ ALPHANUMERIC = ascii_letters + digits +class StreamingState(Enum): + """Enum for tracking the current streaming parsing state.""" + WAITING_FOR_TOOL_START = "waiting_for_tool_start" + PARSING_NAME = "parsing_name" + PARSING_ARGUMENTS = "parsing_arguments" + TOOL_COMPLETE = "tool_complete" + ALL_TOOLS_COMPLETE = "all_tools_complete" + + class MistralToolCall(ToolCall): id: str = Field( default_factory=lambda: MistralToolCall.generate_random_id()) @@ -68,17 +74,50 @@ def __init__(self, tokenizer: AnyTokenizer): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: list[dict] = [] + self.json_decoder: json.JSONDecoder = json.JSONDecoder() + + # Optimized regex patterns + self.tool_call_first_attribute_name: re.Pattern[str] = re.compile( + r'.*\s*"name"\s*:\s*') + self.string_value_pattern: re.Pattern[str] = re.compile( + r'\s*"(.*?)(? Union[DeltaMessage, None]: + """ + Extract tool calls from a streaming response, specifically for the + v11 MistralTokenizer format: ToolName{arguments}. This logic is a + streaming equivalent of the `self.fn_name_regex` used in + non-streaming extraction. + """ + logger.debug("v11 streaming: raw_tool_calls='%s'", self.raw_tool_calls) + logger.debug("v11 streaming: current_tool_name_sent='%s'", + self.current_tool_name_sent) + logger.debug("v11 streaming: prev_args_sent='%s'", self.prev_args_sent) + + result_tool_calls: list[DeltaToolCall] = [] + + while True: + advanced = False + if self.current_tool_name_finished and \ + self.current_tool_arguments_finished and \ + self._should_advance_to_next_v11_tool(): + # Remove the completed tool from raw_tool_calls + # before resetting state + completed_tool_end = self._find_completed_v11_tool_end() + if completed_tool_end > 0: + self.raw_tool_calls = self.raw_tool_calls[ + completed_tool_end:] + self._reset_v11_tool_state() + logger.debug("v11 streaming: found next tool, resetting state") + advanced = True + + sent_something = False + + # Phase 1: Extract and send function name + if not self.current_tool_name_sent: + # Look for function name pattern: name followed by { + brace_index = self.raw_tool_calls.find("{") + if brace_index == -1: + logger.debug("v11 streaming: no opening brace found yet") + break + + # Extract function name + func_name = self.raw_tool_calls[:brace_index].strip() + # Remove any leading separators from previous tools + func_name = re.sub(r'^[\s,]*', '', func_name) + + if not func_name: + logger.debug("v11 streaming: function name is empty") + break + + logger.debug("v11 streaming: sending function name='%s'", + func_name) + self.current_tool_name_sent = True + self.current_tool_name_finished = True + self.current_tool_id += 1 + + result_tool_calls.append( + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall(name=func_name).model_dump( + exclude_none=True), + )) + sent_something = True + + # Phase 2: Extract and send argument fragments + if self.current_tool_name_sent and \ + not self.current_tool_arguments_finished: + # Find the arguments part (everything after the first {) + brace_index = self.raw_tool_calls.find("{") + if brace_index == -1: + logger.debug( + "v11 streaming: no opening brace found for args") + break + + current_args = self.raw_tool_calls[brace_index:] + logger.debug("v11 streaming: current_args='%s'", current_args) + + actual_args = current_args + try: + parsed_obj, end_idx = self.json_decoder.raw_decode( + current_args) + # JSON is complete + self.current_tool_arguments_finished = True + actual_args = current_args[:end_idx] + logger.debug("v11 streaming: JSON complete, parsed_obj=%s", + parsed_obj) + except json.decoder.JSONDecodeError: + # JSON still incomplete + logger.debug("v11 streaming: JSON still incomplete") + pass + + # Calculate what's new since last time + new_content = "" + if actual_args != self.prev_args_sent: + if self.prev_args_sent and actual_args.startswith( + self.prev_args_sent): + # Incremental update + new_content = actual_args[len(self.prev_args_sent):] + logger.debug("v11 streaming: incremental args='%s'", + new_content) + else: + # First time or reset + new_content = actual_args + logger.debug("v11 streaming: first/reset args='%s'", + new_content) + + self.prev_args_sent = actual_args + + if new_content: + result_tool_calls.append( + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=new_content).model_dump( + exclude_none=True), + )) + sent_something = True + + if not sent_something and not advanced: + break + + if result_tool_calls: + return DeltaMessage( + content=additional_content, + tool_calls=result_tool_calls, + ) + + return self._none_or_additional_content(additional_content) + + def _should_advance_to_next_v11_tool(self) -> bool: + """Check if we should advance to the next tool in V11 format.""" + completed_tool_end = self._find_completed_v11_tool_end() + if completed_tool_end <= 0: + return False + + # Check if there's content after the completed tool + # that looks like another tool + remaining = self.raw_tool_calls[completed_tool_end:].strip() + if remaining.startswith(','): + remaining = remaining[1:].strip() + + # Look for next tool pattern: function_name{ + return bool(re.match(r'[a-zA-Z0-9_-]+\s*\{', remaining)) + + def _find_completed_v11_tool_end(self) -> int: + """ + Find the end position of the first completed tool in V11 format + using JSON parsing. + """ + # Look for function name pattern: name followed by { + brace_match = re.search(r'([a-zA-Z0-9_-]+)\s*(\{)', + self.raw_tool_calls) + if not brace_match: + return -1 + + # Try to parse the JSON starting from the opening brace + json_start = brace_match.start(2) + json_part = self.raw_tool_calls[json_start:] + + try: + _, end_idx = self.json_decoder.raw_decode(json_part) + return json_start + end_idx + except json.JSONDecodeError: + return -1 + + def _reset_v11_tool_state(self) -> None: + """Reset V11 tool parsing state for the next tool.""" + self.current_tool_name_finished = False + self.current_tool_arguments_finished = False + self.current_tool_name_sent = False + self.prev_args_sent = "" + + def _determine_next_parsing_element(self) \ + -> Union[Literal["name", "arguments"], None]: + """ + Determine the next element to parse based on current state. + + Returns: + The next element to parse, or None if nothing is ready + """ + # Check for name attribute + if not self.current_tool_name_finished: + match_name = self.tool_call_first_attribute_name.match( + self.raw_tool_calls, self.current_tool_start_index) + if match_name and match_name.end( + ) > self.current_tool_start_index \ + + self.previous_attribute_end_index: + self.current_attribute_start_index = match_name.end() \ + - self.current_tool_start_index + return "name" + + # Check for arguments attribute + if not self.current_tool_arguments_finished: + match_arguments = self.tool_call_first_attribute_arguments.match( + self.raw_tool_calls, self.current_tool_start_index) + if match_arguments and match_arguments.end( + ) > self.current_tool_start_index \ + + self.previous_attribute_end_index: + # The `{` is the last character in the match. + # We want it as start index. + self.current_attribute_start_index = match_arguments.end() \ + - 1 - self.current_tool_start_index + return "arguments" + + return None + + def _is_current_tool_complete(self) -> bool: + """Check if the current tool parsing is complete.""" + return (self.current_tool_name_finished + and self.current_tool_arguments_finished) + + def _advance_to_next_tool(self) -> bool: + """ + Advance to the next tool if available. + + Returns: + True if successfully advanced to next tool, False otherwise + """ + next_tool_start_index = self._next_tool_starting_position() + if next_tool_start_index > 0: + self.current_tool_id += 1 + self.current_tool_start_index = next_tool_start_index + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = 0 + self.current_tool_name_finished = False + self.current_tool_arguments_finished = False + return True + return False + + def _process_delta_text(self, delta_text: str) -> str: + """ + Process delta text and update raw_tool_calls, returning any additional + content. + + Args: + delta_text: The new text delta to process + + Returns: + Any additional content that appears before the bot token + """ + additional_content = "" + + if self.bot_token in delta_text: + # Split only once for efficiency + parts = delta_text.split(self.bot_token, 1) + if len(parts) > 1: + if parts[0]: # Content before bot token + additional_content = parts[0] + # Process content after bot token + tool_content = parts[1].lstrip() + self.raw_tool_calls += tool_content + else: + # No bot token in delta, just clean and append + self.raw_tool_calls += delta_text + # Remove leading spaces only if we have content + if self.raw_tool_calls: + self.raw_tool_calls = self.raw_tool_calls.lstrip() + + return additional_content + + def _should_detect_v11_format(self) -> bool: + """Check if we should attempt V11 format detection.""" + return (self.fn_name_regex is not None and self.current_tool_id == -1 + and not self.v11_tool_format) + + def _detect_v11_format(self) -> None: + """Detect if we're using V11 tool format.""" + stripped_calls = self.raw_tool_calls.lstrip() + if stripped_calls and stripped_calls[0] != "[": + logger.debug("flipping v11 tool format to True ...") + self.v11_tool_format = True + + def _try_parse_json_cached(self, text: str) -> tuple[bool, int]: + """ + Attempt to parse JSON with caching for performance. + + Args: + text: The text to parse as JSON + + Returns: + Tuple of (success, end_index) + """ + if text == self._last_json_parse_input: + return self._last_json_parse_result + + try: + _, end_index = self.json_decoder.raw_decode(text) + result = (True, end_index) + except json.decoder.JSONDecodeError: + result = (False, -1) + + # Cache the result + self._last_json_parse_input = text + self._last_json_parse_result = result + return result + + def _extracted_complete_name( + self, current_attribute_start_index: int) \ + -> tuple[str, Union[int, None]]: + """ + Extract the complete function name from the current tool call. + + Args: + current_attribute_start_index: The starting index of the + name attribute relative to the current tool start + + Returns: + tuple: + - The function name, or "" if extraction failed + - The end index of the name relative to the current tool start, + or None if extraction failed + """ + absolute_start = self.current_tool_start_index \ + + current_attribute_start_index + if match := self.string_value_pattern.match(\ + self.raw_tool_calls, absolute_start): + return match.group(1), match.end() - self.current_tool_start_index + return "", None + + def _extract_argument_fragment(self, current_attribute_start_index: int, + delta: str) -> tuple[str, int]: + """ + Extract the relevant argument fragment from the current streaming delta. + + Args: + current_attribute_start_index: The starting index + of the arguments attribute relative to the current tool start + delta: The new text added in this streaming step + + Returns: + tuple: + - The extracted argument diff text + to be sent in the streaming response + - The end index of the arguments relative to the current tool start, + or -1 if not yet complete + """ + absolute_start = self.current_tool_start_index \ + + current_attribute_start_index + partial_arguments_value = self.raw_tool_calls[absolute_start:] + try: + _, end_index = self.json_decoder.raw_decode( + partial_arguments_value) + return ( + delta[:len(delta) + end_index - len(partial_arguments_value)], + current_attribute_start_index + end_index, + ) + except json.decoder.JSONDecodeError: + # The arguments object is not complete + + # delta contains data from before the argument start + if len(delta) > len(partial_arguments_value): + return delta[-len(partial_arguments_value):], -1 + + # We can send the whole delta + return delta, -1 + + def _next_tool_starting_position(self) -> int: + """ + Find the starting position of the next tool + in the raw tool calls string. + + Returns: + The index position where the next tool starts, + or -1 if no next tool is found yet + """ + assert self.current_tool_start_index >= 0 + try: + _, end_index = self.json_decoder.raw_decode( + self.raw_tool_calls, self.current_tool_start_index) + # Look for the next opening brace after the current tool ends + search_start = self.current_tool_start_index + end_index + next_brace = self.raw_tool_calls.find("{", search_start) + return next_brace if next_brace != -1 else -1 + except json.decoder.JSONDecodeError: + # The current tool object is not yet closed + return -1 + except IndexError: + # The next tool has not started yet + # and the delta just closes the current tool call + return -1 + + def _none_or_additional_content( + self, additional_content: str) -> Union[DeltaMessage, None]: + """ + Create a DeltaMessage with additional content if present, + otherwise return None. + + Args: + additional_content: The text content to include in the message + + Returns: + A DeltaMessage with the additional content, + or None if no content is provided + """ + if additional_content: + return DeltaMessage(content=additional_content) + return None + def adjust_request( self, request: ChatCompletionRequest) -> ChatCompletionRequest: if not isinstance( @@ -125,19 +564,55 @@ def extract_tool_calls( # jsons is difficult try: if self.fn_name_regex: - matches = self.fn_name_regex.findall(tool_content) - function_call_arr = [] - for match in matches: - fn_name = match[0] - args = match[1] - - # fn_name is encoded outside serialized json dump - # only arguments are serialized - function_call_arr.append({ - "name": fn_name, - "arguments": json.loads(args) - }) + pos = 0 + tool_str = tool_content + while pos < len(tool_str): + # skip ws + while pos < len(tool_str) and tool_str[pos].isspace(): + pos += 1 + if pos >= len(tool_str): + break + + # match name + match_name = re.match(r'([a-zA-Z0-9_-]+)', + tool_str[pos:]) + if not match_name: + break + fn_name = match_name.group(0) + pos += match_name.end() + + # skip ws + while pos < len(tool_str) and tool_str[pos].isspace(): + pos += 1 + + if pos >= len(tool_str) or tool_str[pos] != '{': + break + + pos += 1 # skip { + + # parse args + try: + args_obj, end_idx = self.json_decoder.raw_decode( + tool_str[pos:]) + function_call_arr.append({ + "name": fn_name, + "arguments": args_obj + }) + pos += end_idx + except json.JSONDecodeError: + break + + # skip ws + while pos < len(tool_str) and tool_str[pos].isspace(): + pos += 1 + + # optional comma + if pos < len(tool_str) and tool_str[pos] == ',': + pos += 1 + while pos < len( + tool_str) and tool_str[pos].isspace(): + pos += 1 else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: @@ -145,7 +620,8 @@ def extract_tool_calls( # NOTE: This use case should not happen if the model is trained # correctly. It's a easy possible fix so it's included, but # can be brittle for very complex / highly nested tool calls - raw_tool_call = self.tool_call_regex.findall(tool_content)[0] + raw_tool_call = self.tool_call_regex.search( + tool_content).group(0) function_call_arr = json.loads(raw_tool_call) # Tool Call @@ -185,185 +661,114 @@ def extract_tool_calls_streaming( request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool + # Early return if no tool call token present if self.bot_token not in current_text: return DeltaMessage(content=delta_text) - # if the tool call token ID IS in the tokens generated so far, that - # means we're parsing as tool calls now - - # handle if we detected the BOT token which means the start of tool - # calling - if (self.bot_token_id in delta_token_ids - and len(delta_token_ids) == 1): - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - try: - - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[-1] - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - try: - tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') - return None - - # select as the current tool call the one we're on the state at - - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. - if self.current_tool_id >= 0: - diff: Union[str, None] = current_tool_call.get("arguments") - - if diff: - diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta + # Process delta text and extract additional content + additional_content = self._process_delta_text(delta_text) - # case: update an existing tool - this is handled below + # Detect and handle V11 format + if self._should_detect_v11_format(): + self._detect_v11_format() - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) - self.current_tool_name_sent = True - else: - delta = None + if self.v11_tool_format: + return self._extract_tool_calls_streaming_v11( + additional_content, delta_text) - # now we know we're on the same tool call and we're streaming - # arguments + # Check if tool calls have started + if self.current_tool_start_index < 0: + bracket_pos = self.raw_tool_calls.find("[") + if bracket_pos >= 0: + self.current_tool_start_index = bracket_pos + 1 + self.current_tool_id += 1 + else: + return self._none_or_additional_content(additional_content) + + # Try to parse complete JSON with caching + parse_success, end_index = self._try_parse_json_cached( + self.raw_tool_calls) + if parse_success: + self.tools_parsing_finished = True + if len(self.raw_tool_calls) > end_index: + additional_content = self.raw_tool_calls[end_index:] + + # Handle tool completion and transition to next tool + if self._is_current_tool_complete(): + if self.tools_parsing_finished: + return self._none_or_additional_content(additional_content) + + if self._advance_to_next_tool(): + # Successfully moved to next tool, continue processing + pass else: + # No next tool ready yet + return self._none_or_additional_content(additional_content) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get("arguments") - - new_text = delta_text.replace("\'", "\"") - if ('"}' in new_text): - new_text = new_text[:new_text.rindex('"}')] - - if not cur_arguments and not prev_arguments: - - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False)[:-2] - logger.debug("finding %s in %s", new_text, - cur_arguments_json) - - if (new_text not in cur_arguments_json): - return None - arguments_delta = cur_arguments_json[:cur_arguments_json. - rindex(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) - - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None + if self.current_tool_start_index >= len(self.raw_tool_calls): + # tool call has not started + return self._none_or_additional_content(additional_content) - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta + # Determine what to parse next + if self.current_element_streaming is None: + next_element = self._determine_next_parsing_element() + if next_element is None: + return self._none_or_additional_content(additional_content) + self.current_element_streaming = next_element - except Exception: - logger.exception("Error trying to handle streaming tool call.") - logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") - return None + if self.current_element_streaming == "name": + try: + function_name, name_end_index = self._extracted_complete_name( + self.current_attribute_start_index) + except IndexError: + # name value has not started being generated + return self._none_or_additional_content(additional_content) + if function_name == "": + return self._none_or_additional_content(additional_content) + else: + assert name_end_index is not None + # because the function name was successfully retrieved + + self.current_tool_name_finished = True + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = name_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ], + ) + return delta + if self.current_element_streaming == "arguments": + try: + diff, arguments_end_index = self._extract_argument_fragment( + self.current_attribute_start_index, + delta_text, + ) + self.current_tool_arguments_finished = arguments_end_index != -1 + if self.current_tool_arguments_finished: + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = arguments_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ], + ) + return delta + except IndexError: + # arguments value has not started being generated + return self._none_or_additional_content(additional_content)