-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
feat: Add streaming support for Mistral v11 tool format #20503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sjuxax
wants to merge
12
commits into
vllm-project:main
Choose a base branch
from
sjuxax:Mistral3.2-tool-call-fix
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,083
−194
Open
Changes from 2 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
976dc82
feat: Add streaming support for v11 tool format in Mistral parser
sjuxax c78d1fb
Bring in tests from #19425
avigny 2dedc6e
Update vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
sjuxax 5966d37
fix: prevent infinite loop in Mistral tool parsing by removing proces…
sjuxax 9670aee
refactor: improve JSON parsing for Mistral tool calls with robust reg…
sjuxax f280821
refactor: Improve quote normalization in tool call parsing to prevent…
sjuxax ed3dc1d
refactor: remove quote normalization from Mistral tool parser
sjuxax d089554
refactor: optimize tool call parsing by removing substring operations…
sjuxax dee4d43
refactor: Replace `X | Y` union syntax with `Union` for Python 3.9 co…
sjuxax b521f50
feat: add comprehensive tests for Mistral v11 tool format
sjuxax 2966baa
ruff/yapf
sjuxax ef4d46c
Via Grok4: attempt to fix non-streaming and multiple tool calls
sjuxax File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
from collections.abc import Generator | ||
from typing import Optional | ||
|
||
import partial_json_parser | ||
import pytest | ||
from partial_json_parser.core.options import Allow | ||
|
||
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, | ||
ToolCall) | ||
from vllm.entrypoints.openai.tool_parsers import MistralToolParser | ||
from vllm.transformers_utils.detokenizer import detokenize_incrementally | ||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer | ||
|
||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def mistral_tokenizer(): | ||
return get_tokenizer(tokenizer_name=MODEL) | ||
|
||
|
||
@pytest.fixture | ||
def mistral_tool_parser(mistral_tokenizer): | ||
return MistralToolParser(mistral_tokenizer) | ||
|
||
|
||
def assert_tool_calls(actual_tool_calls: list[ToolCall], | ||
expected_tool_calls: list[ToolCall]): | ||
assert len(actual_tool_calls) == len(expected_tool_calls) | ||
|
||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls, | ||
expected_tool_calls): | ||
assert isinstance(actual_tool_call.id, str) | ||
assert len(actual_tool_call.id) == 9 | ||
|
||
assert actual_tool_call.type == "function" | ||
assert actual_tool_call.function == expected_tool_call.function, ( | ||
f'got ${actual_tool_call.function}') | ||
|
||
|
||
def stream_delta_message_generator( | ||
mistral_tool_parser: MistralToolParser, | ||
mistral_tokenizer: AnyTokenizer, | ||
model_output: str) -> Generator[DeltaMessage, None, None]: | ||
all_token_ids = mistral_tokenizer.encode(model_output, | ||
add_special_tokens=False) | ||
|
||
previous_text = "" | ||
previous_tokens = None | ||
prefix_offset = 0 | ||
read_offset = 0 | ||
for i, delta_token in enumerate(all_token_ids): | ||
delta_token_ids = [delta_token] | ||
previous_token_ids = all_token_ids[:i] | ||
current_token_ids = all_token_ids[:i + 1] | ||
|
||
(new_tokens, delta_text, new_prefix_offset, | ||
new_read_offset) = detokenize_incrementally( | ||
tokenizer=mistral_tokenizer, | ||
all_input_ids=current_token_ids, | ||
prev_tokens=previous_tokens, | ||
prefix_offset=prefix_offset, | ||
read_offset=read_offset, | ||
skip_special_tokens=False, | ||
spaces_between_special_tokens=True, | ||
) | ||
|
||
current_text = previous_text + delta_text | ||
|
||
delta_message = mistral_tool_parser.extract_tool_calls_streaming( | ||
previous_text, | ||
current_text, | ||
delta_text, | ||
previous_token_ids, | ||
current_token_ids, | ||
delta_token_ids, | ||
request=None, # type: ignore[arg-type] | ||
) | ||
if delta_message: | ||
yield delta_message | ||
|
||
previous_text = current_text | ||
previous_tokens = previous_tokens + new_tokens if previous_tokens\ | ||
else new_tokens | ||
prefix_offset = new_prefix_offset | ||
read_offset = new_read_offset | ||
|
||
|
||
def test_extract_tool_calls_no_tools(mistral_tool_parser): | ||
model_output = "This is a test" | ||
extracted_tool_calls = mistral_tool_parser.extract_tool_calls( | ||
model_output, request=None) # type: ignore[arg-type] | ||
assert not extracted_tool_calls.tools_called | ||
assert extracted_tool_calls.tool_calls == [] | ||
assert extracted_tool_calls.content == model_output | ||
|
||
|
||
@pytest.mark.parametrize( | ||
ids=[ | ||
"single_tool_add", "single_tool_weather", "argument_before_name", | ||
"argument_before_name_and_name_in_argument" | ||
], | ||
argnames=["model_output", "expected_tool_calls", "expected_content"], | ||
argvalues=[ | ||
( | ||
'''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="add", | ||
arguments=json.dumps({ | ||
"a": 3.5, | ||
"b": 4 | ||
}))) | ||
], | ||
None), | ||
( | ||
'''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="get_current_weather", | ||
arguments=json.dumps( | ||
{ | ||
"city": "San Francisco", | ||
"state": "CA", | ||
"unit": "celsius" | ||
}))) | ||
], | ||
None), | ||
( | ||
'''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="get_current_weather", | ||
arguments=json.dumps( | ||
{ | ||
"city": "San Francisco", | ||
"state": "CA", | ||
"unit": "celsius" | ||
}))) | ||
], | ||
None), | ||
( | ||
'''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="get_age", | ||
arguments=json.dumps({ | ||
"name": | ||
"John Doe", | ||
}))) | ||
], | ||
None), | ||
], | ||
) | ||
def test_extract_tool_calls(mistral_tool_parser, model_output, | ||
expected_tool_calls, expected_content): | ||
extracted_tool_calls = mistral_tool_parser.extract_tool_calls( | ||
model_output, request=None) # type: ignore[arg-type] | ||
assert extracted_tool_calls.tools_called | ||
|
||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) | ||
|
||
assert extracted_tool_calls.content == expected_content | ||
|
||
|
||
@pytest.mark.parametrize( | ||
ids=[ | ||
"no_tools", | ||
"single_tool_add", | ||
"single_tool_add_strings", | ||
"single_tool_weather", | ||
"argument_before_name", | ||
"argument_before_name_and_name_in_argument", | ||
"multiple_tools", | ||
], | ||
argnames=["model_output", "expected_tool_calls", "expected_content"], | ||
argvalues=[ | ||
('''This is a test''', [], '''This is a test'''), | ||
( | ||
'''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="add", | ||
arguments=json.dumps({ | ||
"a": 3, | ||
"b": 4 | ||
}))) | ||
], | ||
""), | ||
( | ||
'''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="add", | ||
arguments=json.dumps({ | ||
"a": "3", | ||
"b": "4" | ||
}))) | ||
], | ||
""), | ||
( | ||
'''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="get_current_weather", | ||
arguments=json.dumps( | ||
{ | ||
"city": "San Francisco", | ||
"state": "CA", | ||
"unit": "celsius" | ||
}))) | ||
], | ||
""), | ||
( | ||
'''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="get_current_weather", | ||
arguments=json.dumps( | ||
{ | ||
"city": "San Francisco", | ||
"state": "CA", | ||
"unit": "celsius" | ||
}))) | ||
], | ||
''), | ||
( | ||
'''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="get_age", | ||
arguments=json.dumps({ | ||
"name": | ||
"John Doe", | ||
}))) | ||
], | ||
''), | ||
( | ||
'''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501 | ||
[ | ||
ToolCall(function=FunctionCall(name="add", | ||
arguments=json.dumps({ | ||
"a": 3.5, | ||
"b": 4 | ||
}))), | ||
ToolCall(function=FunctionCall(name="get_current_weather", | ||
arguments=json.dumps( | ||
{ | ||
"city": "San Francisco", | ||
"state": "CA", | ||
"unit": "celsius" | ||
}))) | ||
], | ||
''), | ||
], | ||
) | ||
def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, | ||
model_output, expected_tool_calls, | ||
expected_content): | ||
other_content: str = '' | ||
function_names: list[str] = [] | ||
function_args_strs: list[str] = [] | ||
tool_call_idx: int = -1 | ||
tool_call_ids: list[Optional[str]] = [] | ||
|
||
for delta_message in stream_delta_message_generator( | ||
mistral_tool_parser, mistral_tokenizer, model_output): | ||
# role should never be streamed from tool parser | ||
assert not delta_message.role | ||
|
||
if delta_message.content: | ||
other_content += delta_message.content | ||
|
||
streamed_tool_calls = delta_message.tool_calls | ||
|
||
if streamed_tool_calls and len(streamed_tool_calls) > 0: | ||
# make sure only one diff is present - correct even for parallel | ||
assert len(streamed_tool_calls) == 1 | ||
tool_call = streamed_tool_calls[0] | ||
|
||
# if a new tool is being called, set up empty arguments | ||
if tool_call.index != tool_call_idx: | ||
tool_call_idx = tool_call.index | ||
function_args_strs.append("") | ||
tool_call_ids.append(None) | ||
|
||
# if a tool call ID is streamed, make sure one hasn't been already | ||
if tool_call.id and not tool_call_ids[tool_call.index]: | ||
tool_call_ids[tool_call.index] = tool_call.id | ||
|
||
# if parts of the function start being streamed | ||
if tool_call.function: | ||
# if the function name is defined, set it. it should be streamed | ||
# IN ENTIRETY, exactly one time. | ||
if tool_call.function.name: | ||
assert isinstance(tool_call.function.name, str) | ||
function_names.append(tool_call.function.name) | ||
|
||
if tool_call.function.arguments: | ||
# make sure they're a string and then add them to the list | ||
assert isinstance(tool_call.function.arguments, str) | ||
|
||
function_args_strs[ | ||
tool_call.index] += tool_call.function.arguments | ||
|
||
assert other_content == expected_content | ||
|
||
actual_tool_calls = [ | ||
ToolCall(id=tool_call_id, | ||
function=FunctionCall( | ||
name=function_name, | ||
arguments=partial_json_parser.ensure_json( | ||
function_args_str, Allow.OBJ | Allow.STR))) | ||
for tool_call_id, function_name, function_args_str in zip( | ||
tool_call_ids, function_names, function_args_strs) | ||
] | ||
assert_tool_calls(actual_tool_calls, expected_tool_calls) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.