Skip to content

feat: Add streaming support for Mistral v11 tool format #20503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 311 additions & 0 deletions tests/tool_use/test_mistral_tool_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0

import json
from collections.abc import Generator
from typing import Optional

import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import MistralToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"


@pytest.fixture(scope="module")
def mistral_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
return MistralToolParser(mistral_tokenizer)


def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)

for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) == 9

assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function, (
f'got ${actual_tool_call.function}')


def stream_delta_message_generator(
mistral_tool_parser: MistralToolParser,
mistral_tokenizer: AnyTokenizer,
model_output: str) -> Generator[DeltaMessage, None, None]:
all_token_ids = mistral_tokenizer.encode(model_output,
add_special_tokens=False)

previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]

(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=mistral_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)

current_text = previous_text + delta_text

delta_message = mistral_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=None, # type: ignore[arg-type]
)
if delta_message:
yield delta_message

previous_text = current_text
previous_tokens = previous_tokens + new_tokens if previous_tokens\
else new_tokens
prefix_offset = new_prefix_offset
read_offset = new_read_offset


def test_extract_tool_calls_no_tools(mistral_tool_parser):
model_output = "This is a test"
extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
ids=[
"single_tool_add", "single_tool_weather", "argument_before_name",
"argument_before_name_and_name_in_argument"
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
'''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="add",
arguments=json.dumps({
"a": 3.5,
"b": 4
})))
],
None),
(
'''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius"
})))
],
None),
(
'''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius"
})))
],
None),
(
'''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_age",
arguments=json.dumps({
"name":
"John Doe",
})))
],
None),
],
)
def test_extract_tool_calls(mistral_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called

assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool_add",
"single_tool_add_strings",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
"multiple_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
('''This is a test''', [], '''This is a test'''),
(
'''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="add",
arguments=json.dumps({
"a": 3,
"b": 4
})))
],
""),
(
'''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="add",
arguments=json.dumps({
"a": "3",
"b": "4"
})))
],
""),
(
'''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius"
})))
],
""),
(
'''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius"
})))
],
''),
(
'''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_age",
arguments=json.dumps({
"name":
"John Doe",
})))
],
''),
(
'''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501
[
ToolCall(function=FunctionCall(name="add",
arguments=json.dumps({
"a": 3.5,
"b": 4
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius"
})))
],
''),
],
)
def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer,
model_output, expected_tool_calls,
expected_content):
other_content: str = ''
function_names: list[str] = []
function_args_strs: list[str] = []
tool_call_idx: int = -1
tool_call_ids: list[Optional[str]] = []

for delta_message in stream_delta_message_generator(
mistral_tool_parser, mistral_tokenizer, model_output):
# role should never be streamed from tool parser
assert not delta_message.role

if delta_message.content:
other_content += delta_message.content

streamed_tool_calls = delta_message.tool_calls

if streamed_tool_calls and len(streamed_tool_calls) > 0:
# make sure only one diff is present - correct even for parallel
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]

# if a new tool is being called, set up empty arguments
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
function_args_strs.append("")
tool_call_ids.append(None)

# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id and not tool_call_ids[tool_call.index]:
tool_call_ids[tool_call.index] = tool_call.id

# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert isinstance(tool_call.function.name, str)
function_names.append(tool_call.function.name)

if tool_call.function.arguments:
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)

function_args_strs[
tool_call.index] += tool_call.function.arguments

assert other_content == expected_content

actual_tool_calls = [
ToolCall(id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR)))
for tool_call_id, function_name, function_args_str in zip(
tool_call_ids, function_names, function_args_strs)
]
assert_tool_calls(actual_tool_calls, expected_tool_calls)
Loading
Loading