From e2efa8f3283e2786589ecd7997b2bf74a73ccd4c Mon Sep 17 00:00:00 2001 From: m-misiura Date: Thu, 26 Jun 2025 13:48:16 +0100 Subject: [PATCH 01/10] :sparkles: added a new endpoint to extract tokenizer and chat template information Signed-off-by: m-misiura --- tests/entrypoints/openai/test_tokenization.py | 92 +++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 10 ++ vllm/entrypoints/openai/protocol.py | 8 +- .../openai/serving_tokenization.py | 62 ++++++++++++- 4 files changed, 168 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 57dd25fe1b16..d2f37e6c3d67 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -283,3 +283,95 @@ async def test_detokenize( response.raise_for_status() assert response.json() == {"prompt": prompt} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], +) +async def test_get_tokenizer_info_basic( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): + """Test basic tokenizer info endpoint functionality.""" + response = requests.get(server.url_for("get_tokenizer_info")) + response.raise_for_status() + result = response.json() + assert "tokenizer_class" in result + assert isinstance(result["tokenizer_class"], str) + assert result["tokenizer_class"] + + +@pytest.mark.asyncio +async def test_get_tokenizer_info_schema(server: RemoteOpenAIServer): + """Test that the response matches expected schema types.""" + response = requests.get(server.url_for("get_tokenizer_info")) + response.raise_for_status() + result = response.json() + field_types = { + "add_bos_token": bool, + "add_prefix_space": bool, + "clean_up_tokenization_spaces": bool, + "split_special_tokens": bool, + "bos_token": str, + "eos_token": str, + "pad_token": str, + "unk_token": str, + "chat_template": str, + "errors": str, + "model_max_length": int, + "additional_special_tokens": list, + "added_tokens_decoder": dict, + } + for field, expected_type in field_types.items(): + if field in result and result[field] is not None: + assert isinstance(result[field], expected_type), f"{field} should be {expected_type.__name__}" + + +@pytest.mark.asyncio +async def test_get_tokenizer_info_added_tokens_structure(server: RemoteOpenAIServer): + """Test added_tokens_decoder structure if present.""" + response = requests.get(server.url_for("get_tokenizer_info")) + response.raise_for_status() + result = response.json() + added_tokens = result.get("added_tokens_decoder") + if added_tokens: + for token_id, token_info in added_tokens.items(): + assert isinstance(token_id, str), "Token IDs should be strings" + assert isinstance(token_info, dict), "Token info should be a dict" + assert "content" in token_info, "Token info should have content" + assert "special" in token_info, "Token info should have special flag" + assert isinstance(token_info["special"], bool), "Special flag should be boolean" + + +@pytest.mark.asyncio +async def test_get_tokenizer_info_consistency_with_tokenize(server: RemoteOpenAIServer): + """Test that tokenizer info is consistent with tokenization endpoint.""" + info_response = requests.get(server.url_for("get_tokenizer_info")) + info_response.raise_for_status() + info = info_response.json() + tokenize_response = requests.post( + server.url_for("tokenize"), + json={"model": MODEL_NAME, "prompt": "Hello 
world!"} + ) + tokenize_response.raise_for_status() + tokenize_result = tokenize_response.json() + info_max_len = info.get("model_max_length") + tokenize_max_len = tokenize_result.get("max_model_len") + if info_max_len and tokenize_max_len: + assert info_max_len >= tokenize_max_len, "Info max length should be >= tokenize max length" + + +@pytest.mark.asyncio +async def test_get_tokenizer_info_chat_template(server: RemoteOpenAIServer): + """Test chat template is properly included.""" + response = requests.get(server.url_for("get_tokenizer_info")) + response.raise_for_status() + result = response.json() + chat_template = result.get("chat_template") + if chat_template: + assert isinstance(chat_template, str), "Chat template should be a string" + assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e3285a9bf76d..523cbffe8fe6 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -73,6 +73,7 @@ ResponsesResponse, ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, + TokenizerInfoResponse, TranscriptionRequest, TranscriptionResponse, TranslationRequest, @@ -523,6 +524,15 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): assert_never(generator) +@router.get("/get_tokenizer_info") +async def get_tokenizer_info(raw_request: Request): + """Get comprehensive tokenizer information.""" + result = await tokenization(raw_request).get_tokenizer_info() + return JSONResponse( + content=result.model_dump(), + status_code=result.code if isinstance(result, ErrorResponse) else 200) + + @router.get("/v1/models") async def show_available_models(raw_request: Request): handler = models(raw_request) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 14b2253d1dba..b0fad469ab5d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,7 @@ import json import time from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Literal, Optional, Union +from typing import Annotated, Any, ClassVar, Dict, List, Literal, Optional, Union import regex as re import torch @@ -1849,6 +1849,12 @@ class DetokenizeResponse(OpenAIBaseModel): prompt: str +class TokenizerInfoResponse(OpenAIBaseModel): + """Response containing tokenizer configuration equivalent to tokenizer_config.json""" + tokenizer_class: str + model_config = ConfigDict(extra="allow") + + class LoadLoRAAdapterRequest(BaseModel): lora_name: str lora_path: str diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 3db0a71fadd1..e5c522a4992e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Final, Optional, Union +import re +from functools import lru_cache +from typing import Any, Dict, Final, List, Optional, Tuple, Union import jinja2 from fastapi import Request @@ -17,11 +18,15 @@ ErrorResponse, TokenizeChatRequest, TokenizeRequest, - TokenizeResponse) + TokenizeResponse, + TokenizerInfoResponse) # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger +from 
vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, + encode_tokens) +from vllm.transformers_utils.tokenizers import MistralTokenizer logger = init_logger(__name__) @@ -155,3 +160,54 @@ async def create_detokenize( input_text = prompt_input["prompt"] return DetokenizeResponse(prompt=input_text) + + async def get_tokenizer_info( + self) -> Union[TokenizerInfoResponse, ErrorResponse]: + """Get comprehensive tokenizer information.""" + try: + tokenizer = await self.engine_client.get_tokenizer() + info = TokenizerInfo(tokenizer, self.model_config, + self.chat_template).to_dict() + return TokenizerInfoResponse(**info) + except Exception as e: + return self.create_error_response( + f"Failed to get tokenizer info: {str(e)}") + + +class TokenizerInfo: + + def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig, + chat_template: Optional[str]): + self.tokenizer = tokenizer + self.model_config = model_config + self.chat_template = chat_template + + def to_dict(self) -> Dict[str, Any]: + """Return the tokenizer configuration.""" + return self._get_tokenizer_config() + + # Use the tokenizer's init_kwargs as the base (this contains the original config) + def _get_tokenizer_config(self) -> Dict[str, Any]: + """Get tokenizer configuration directly from the tokenizer object.""" + config = dict(self.tokenizer.init_kwargs) if hasattr(self.tokenizer, 'init_kwargs') and self.tokenizer.init_kwargs else {} + + # Remove file path fields + config.pop('vocab_file', None) + config.pop('merges_file', None) + + config = self._make_json_serializable(config) + config['tokenizer_class'] = self.tokenizer.__class__.__bases__[0].__name__ + if self.chat_template: + config['chat_template'] = self.chat_template + return config + + def _make_json_serializable(self, obj): + """Convert any non-JSON-serializable objects to serializable format.""" + if hasattr(obj, 'content'): + return obj.content + elif isinstance(obj, dict): + return {k: self._make_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._make_json_serializable(item) for item in obj] + else: + return obj \ No newline at end of file From ade7167194c2545ad5413a9a6ebe27d837610444 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Tue, 8 Jul 2025 09:42:20 +0100 Subject: [PATCH 02/10] :construction: made the get_tokenizer_info opt-in; disable unless a flag is passed to vllm serve --- vllm/entrypoints/openai/api_server.py | 19 +++++++++++-------- vllm/entrypoints/openai/cli_args.py | 7 +++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 523cbffe8fe6..0d9dbce98e69 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -524,13 +524,16 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): assert_never(generator) -@router.get("/get_tokenizer_info") -async def get_tokenizer_info(raw_request: Request): - """Get comprehensive tokenizer information.""" - result = await tokenization(raw_request).get_tokenizer_info() - return JSONResponse( - content=result.model_dump(), - status_code=result.code if isinstance(result, ErrorResponse) else 200) +def maybe_register_tokenizer_info_endpoint(args): + """Conditionally register the tokenizer info endpoint if enabled.""" + if getattr(args, 'enable_tokenizer_info_endpoint', False): + @router.get("/get_tokenizer_info") + async def get_tokenizer_info(raw_request: Request): + """Get comprehensive tokenizer 
information.""" + result = await tokenization(raw_request).get_tokenizer_info() + return JSONResponse( + content=result.model_dump(), + status_code=result.code if isinstance(result, ErrorResponse) else 200) @router.get("/v1/models") @@ -1541,8 +1544,8 @@ async def run_server_worker(listen_address, uvicorn_kwargs['log_config'] = log_config async with build_async_engine_client(args, client_config) as engine_client: + maybe_register_tokenizer_info_endpoint(args) app = build_app(args) - vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4f8aaab772fd..0ca09460c0c8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -295,6 +295,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "If set to True, enable tracking server_load_metrics in the app state." ) + parser.add_argument( + "--enable-tokenizer-info-endpoint", + action='store_true', + default=False, + help="Enable the /get_tokenizer_info endpoint. May expose chat " + "templates and other tokenizer configuration." + ) return parser From 855b7bf32050097f4560d151688d791bbcf8d8a4 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Tue, 8 Jul 2025 14:18:04 +0100 Subject: [PATCH 03/10] :construction: renamed endpoint from `get_tokenizer_info` to `tokenizer_info` and ran pre-commit --- tests/entrypoints/openai/test_tokenization.py | 112 ++++++----- vllm/entrypoints/openai/api_server.py | 14 +- vllm/entrypoints/openai/cli_args.py | 5 +- vllm/entrypoints/openai/protocol.py | 187 ++++++++++-------- .../openai/serving_tokenization.py | 110 ++++++----- 5 files changed, 239 insertions(+), 189 deletions(-) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index d2f37e6c3d67..0980bf5bcab8 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -41,8 +41,8 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 @pytest.fixture(scope="module") def tokenizer_name(model_name: str, zephyr_lora_added_tokens_files: str): # noqa: F811 - return zephyr_lora_added_tokens_files if ( - model_name == "zephyr-lora2") else model_name + return (zephyr_lora_added_tokens_files if + (model_name == "zephyr-lora2") else model_name) @pytest_asyncio.fixture @@ -69,12 +69,14 @@ async def test_tokenize_completions( prompt = "vllm1 This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(server.url_for("tokenize"), - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt - }) + response = requests.post( + server.url_for("tokenize"), + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt, + }, + ) response.raise_for_status() result = response.json() @@ -100,16 +102,20 @@ async def test_tokenize_chat( for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": "user", - "content": "Hi there!" - }, { - "role": "assistant", - "content": "Nice to meet you!" - }, { - "role": "user", - "content": "Can I ask a question? vllm1" - }] + conversation = [ + { + "role": "user", + "content": "Hi there!" + }, + { + "role": "assistant", + "content": "Nice to meet you!" + }, + { + "role": "user", + "content": "Can I ask a question? 
vllm1" + }, + ] for continue_final in [False, True]: if add_generation and continue_final: continue @@ -123,20 +129,21 @@ async def test_tokenize_chat( add_generation_prompt=add_generation, continue_final_message=continue_final, conversation=conversation, - tokenize=False) + tokenize=False, + ) tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(server.url_for("tokenize"), - json={ - "add_generation_prompt": - add_generation, - "continue_final_message": - continue_final, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name - }) + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + }, + ) response.raise_for_status() result = response.json() @@ -275,11 +282,13 @@ async def test_detokenize( prompt = "This is a test prompt. vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) - response = requests.post(server.url_for("detokenize"), - json={ - "model": model_name, - "tokens": tokens - }) + response = requests.post( + server.url_for("detokenize"), + json={ + "model": model_name, + "tokens": tokens + }, + ) response.raise_for_status() assert response.json() == {"prompt": prompt} @@ -302,10 +311,10 @@ async def test_get_tokenizer_info_basic( result = response.json() assert "tokenizer_class" in result assert isinstance(result["tokenizer_class"], str) - assert result["tokenizer_class"] + assert result["tokenizer_class"] -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_get_tokenizer_info_schema(server: RemoteOpenAIServer): """Test that the response matches expected schema types.""" response = requests.get(server.url_for("get_tokenizer_info")) @@ -313,7 +322,7 @@ async def test_get_tokenizer_info_schema(server: RemoteOpenAIServer): result = response.json() field_types = { "add_bos_token": bool, - "add_prefix_space": bool, + "add_prefix_space": bool, "clean_up_tokenization_spaces": bool, "split_special_tokens": bool, "bos_token": str, @@ -328,11 +337,14 @@ async def test_get_tokenizer_info_schema(server: RemoteOpenAIServer): } for field, expected_type in field_types.items(): if field in result and result[field] is not None: - assert isinstance(result[field], expected_type), f"{field} should be {expected_type.__name__}" + assert isinstance( + result[field], + expected_type), (f"{field} should be {expected_type.__name__}") @pytest.mark.asyncio -async def test_get_tokenizer_info_added_tokens_structure(server: RemoteOpenAIServer): +async def test_get_tokenizer_info_added_tokens_structure( + server: RemoteOpenAIServer, ): """Test added_tokens_decoder structure if present.""" response = requests.get(server.url_for("get_tokenizer_info")) response.raise_for_status() @@ -343,26 +355,33 @@ async def test_get_tokenizer_info_added_tokens_structure(server: RemoteOpenAISer assert isinstance(token_id, str), "Token IDs should be strings" assert isinstance(token_info, dict), "Token info should be a dict" assert "content" in token_info, "Token info should have content" - assert "special" in token_info, "Token info should have special flag" - assert isinstance(token_info["special"], bool), "Special flag should be boolean" + assert "special" in token_info, ( + "Token info should have special flag") + assert isinstance(token_info["special"], + bool), ("Special flag should be boolean") @pytest.mark.asyncio -async def 
test_get_tokenizer_info_consistency_with_tokenize(server: RemoteOpenAIServer): +async def test_get_tokenizer_info_consistency_with_tokenize( + server: RemoteOpenAIServer, ): """Test that tokenizer info is consistent with tokenization endpoint.""" info_response = requests.get(server.url_for("get_tokenizer_info")) info_response.raise_for_status() info = info_response.json() tokenize_response = requests.post( server.url_for("tokenize"), - json={"model": MODEL_NAME, "prompt": "Hello world!"} + json={ + "model": MODEL_NAME, + "prompt": "Hello world!" + }, ) tokenize_response.raise_for_status() tokenize_result = tokenize_response.json() info_max_len = info.get("model_max_length") tokenize_max_len = tokenize_result.get("max_model_len") if info_max_len and tokenize_max_len: - assert info_max_len >= tokenize_max_len, "Info max length should be >= tokenize max length" + assert info_max_len >= tokenize_max_len, ( + "Info max length should be >= tokenize max length") @pytest.mark.asyncio @@ -373,5 +392,6 @@ async def test_get_tokenizer_info_chat_template(server: RemoteOpenAIServer): result = response.json() chat_template = result.get("chat_template") if chat_template: - assert isinstance(chat_template, str), "Chat template should be a string" - assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file + assert isinstance(chat_template, + str), ("Chat template should be a string") + assert chat_template.strip(), "Chat template should not be empty" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0d9dbce98e69..cb6675cdcbb6 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -73,7 +73,6 @@ ResponsesResponse, ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, - TokenizerInfoResponse, TranscriptionRequest, TranscriptionResponse, TranslationRequest, @@ -527,15 +526,16 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): def maybe_register_tokenizer_info_endpoint(args): """Conditionally register the tokenizer info endpoint if enabled.""" if getattr(args, 'enable_tokenizer_info_endpoint', False): - @router.get("/get_tokenizer_info") + + @router.get("/tokenizer_info") async def get_tokenizer_info(raw_request: Request): """Get comprehensive tokenizer information.""" result = await tokenization(raw_request).get_tokenizer_info() - return JSONResponse( - content=result.model_dump(), - status_code=result.code if isinstance(result, ErrorResponse) else 200) - - + return JSONResponse(content=result.model_dump(), + status_code=result.code if isinstance( + result, ErrorResponse) else 200) + + @router.get("/v1/models") async def show_available_models(raw_request: Request): handler = models(raw_request) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 0ca09460c0c8..bf4d754aec15 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -299,9 +299,8 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-tokenizer-info-endpoint", action='store_true', default=False, - help="Enable the /get_tokenizer_info endpoint. May expose chat " - "templates and other tokenizer configuration." - ) + help="Enable the /tokenizer_info endpoint. 
May expose chat " + "templates and other tokenizer configuration.") return parser diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b0fad469ab5d..9b7db14534c8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,7 @@ import json import time from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Dict, List, Literal, Optional, Union +from typing import Annotated, Any, ClassVar, Literal, Optional, Union import regex as re import torch @@ -129,7 +129,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): description: Optional[str] = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') + json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema") strict: Optional[bool] = None @@ -200,8 +200,8 @@ def get_logits_processors(processors: Optional[LogitsProcessors], if processors and pattern: logits_processors = [] for processor in processors: - qualname = processor if isinstance(processor, - str) else processor.qualname + qualname = (processor + if isinstance(processor, str) else processor.qualname) if not re.match(pattern, qualname): raise ValueError( f"Logits processor '{qualname}' is not allowed by this " @@ -251,7 +251,7 @@ class ResponsesRequest(OpenAIBaseModel): prompt: Optional[ResponsePrompt] = None reasoning: Optional[Reasoning] = None service_tier: Literal["auto", "default", "flex", "scale", - "priority"] = "auto" + "priority"] = ("auto") store: Optional[bool] = True stream: Optional[bool] = False temperature: Optional[float] = None @@ -356,7 +356,8 @@ class ChatCompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = Field( default=None, deprecated= - 'max_tokens is deprecated in favor of the max_completion_tokens field') + "max_tokens is deprecated in favor of the max_completion_tokens field", + ) max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 @@ -418,7 +419,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ("If this is set, the chat will be formatted so that the final " "message in the chat is open-ended, without any EOS tokens. The " "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " + 'This allows you to "prefill" part of the model\'s response for it. ' "Cannot be used at the same time as `add_generation_prompt`."), ) add_special_tokens: bool = Field( @@ -437,7 +438,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "the model if it is performing RAG (retrieval-augmented generation)." " If the template does not support RAG, this argument will have no " "effect. We recommend that each document should be a dict containing " - "\"title\" and \"text\" keys."), + '"title" and "text" keys.'), ) chat_template: Optional[str] = Field( default=None, @@ -518,13 +519,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. 
For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}."), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified."), + ) cache_salt: Optional[str] = Field( default=None, description=( @@ -533,10 +536,12 @@ class ChatCompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0."), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, @@ -558,7 +563,6 @@ class ChatCompletionRequest(OpenAIBaseModel): def to_beam_search_params( self, max_tokens: int, default_sampling_params: dict) -> BeamSearchParams: - n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: temperature = default_sampling_params.get( @@ -579,7 +583,6 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: dict, ) -> SamplingParams: - # Default parameters if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( @@ -657,17 +660,17 @@ def to_sampling_params( logits_processor_pattern), include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA + if self.stream else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - bad_words= self.bad_words, + bad_words=self.bad_words, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) def _get_guided_json_from_tool( - self) -> Optional[Union[str, dict, BaseModel]]: + self, ) -> Optional[Union[str, dict, BaseModel]]: # user has chosen to not use any tool if self.tool_choice == "none" or self.tools is None: return None @@ -692,7 +695,7 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: "properties": { "name": { "type": "string", - "enum": [tool.function.name] + "enum": [tool.function.name], }, # parameters are always generated as '{}' in the final # output if they are missing from the request @@ -702,9 +705,9 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: if tool.function.parameters else { "type": "object", "properties": {} - } + }, }, - "required": ["name", "parameters"] + "required": ["name", "parameters"], } json_schema = { @@ -712,8 +715,8 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: "minItems": 1, "items": { "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools] - } + "anyOf": [get_tool_schema(tool) for tool in self.tools], + }, } return json_schema @@ -759,7 +762,7 @@ def check_guided_decoding_count(cls, data): guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, - 
"guided_choice" in data and data["guided_choice"] is not None + "guided_choice" in data and data["guided_choice"] is not None, ]) # you can only use one kind of guided decoding if guide_count > 1: @@ -779,7 +782,6 @@ def check_guided_decoding_count(cls, data): @model_validator(mode="before") @classmethod def check_tool_usage(cls, data): - # if "tool_choice" is not specified but tools are provided, # default to "auto" tool_choice if "tool_choice" not in data and data.get("tools"): @@ -791,7 +793,6 @@ def check_tool_usage(cls, data): # if "tool_choice" is specified -- validation if "tool_choice" in data: - # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: raise ValueError( @@ -800,18 +801,18 @@ def check_tool_usage(cls, data): # make sure that tool choice is either a named tool # OR that it's set to "auto" or "required" if data["tool_choice"] not in [ - "auto", "required" + "auto", + "required", ] and not isinstance(data["tool_choice"], dict): raise NotImplementedError( - f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ - 'Only named tools, "none", "auto" or "required" '\ - 'are supported.' - ) + f"Invalid value for `tool_choice`: {data['tool_choice']}! " + 'Only named tools, "none", "auto" or "required" ' + "are supported.") # ensure that if "tool_choice" is specified as an object, # it matches a valid tool - correct_usage_message = 'Correct usage: `{"type": "function",' \ - ' "function": {"name": "my_function"}}`' + correct_usage_message = ('Correct usage: `{"type": "function",' + ' "function": {"name": "my_function"}}`') if isinstance(data["tool_choice"], dict): valid_tool = False function = data["tool_choice"].get("function") @@ -823,8 +824,8 @@ def check_tool_usage(cls, data): raise ValueError(f"Expected field `name` in `function` in " f"`tool_choice`! {correct_usage_message}") function_name = function["name"] - if not isinstance(function_name, - str) or len(function_name) == 0: + if (not isinstance(function_name, str) + or len(function_name) == 0): raise ValueError( f"Invalid `name` in `function`: `{function_name}`" f" in `tool_choice`! {correct_usage_message}") @@ -855,8 +856,8 @@ def check_cache_salt_support(cls, data): raise ValueError( "Parameter 'cache_salt' is not supported with " "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: + if (not isinstance(data["cache_salt"], str) + or not data["cache_salt"]): raise ValueError("Parameter 'cache_salt' must be a " "non-empty string if provided.") return data @@ -966,18 +967,21 @@ class CompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. 
For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}."), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified."), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, @@ -1001,7 +1005,6 @@ def to_beam_search_params( max_tokens: int, default_sampling_params: Optional[dict] = None, ) -> BeamSearchParams: - if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 @@ -1024,7 +1027,6 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: Optional[dict] = None, ) -> SamplingParams: - if default_sampling_params is None: default_sampling_params = {} @@ -1096,13 +1098,13 @@ def to_sampling_params( logits_processors=get_logits_processors(self.logits_processors, logits_processor_pattern), truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA + if self.stream else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, - ) + ) @model_validator(mode="before") @classmethod @@ -1110,7 +1112,7 @@ def check_guided_decoding_count(cls, data): guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None + "guided_choice" in data and data["guided_choice"] is not None, ]) if guide_count > 1: raise ValueError( @@ -1281,8 +1283,10 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data) + return PoolingParams( + use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data, + ) class RerankRequest(OpenAIBaseModel): @@ -1308,8 +1312,10 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data) + return PoolingParams( + use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data, + ) class RerankDocument(BaseModel): @@ -1700,7 +1706,7 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: BatchRequestInputBody - @field_validator('body', mode='plain') + @field_validator("body", mode="plain") @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models @@ -1724,8 +1730,12 @@ class BatchResponseData(OpenAIBaseModel): request_id: str # The body of the response. 
- body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, - ScoreResponse, RerankResponse]] = None + body: Optional[Union[ + ChatCompletionResponse, + EmbeddingResponse, + ScoreResponse, + RerankResponse, + ]] = None class BatchRequestOutput(OpenAIBaseModel): @@ -1785,7 +1795,7 @@ class TokenizeChatRequest(OpenAIBaseModel): ("If this is set, the chat will be formatted so that the final " "message in the chat is open-ended, without any EOS tokens. The " "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " + 'This allows you to "prefill" part of the model\'s response for it. ' "Cannot be used at the same time as `add_generation_prompt`."), ) add_special_tokens: bool = Field( @@ -1850,11 +1860,15 @@ class DetokenizeResponse(OpenAIBaseModel): class TokenizerInfoResponse(OpenAIBaseModel): - """Response containing tokenizer configuration equivalent to tokenizer_config.json""" + """ + Response containing tokenizer configuration + equivalent to tokenizer_config.json + """ + tokenizer_class: str model_config = ConfigDict(extra="allow") - - + + class LoadLoRAAdapterRequest(BaseModel): lora_name: str lora_path: str @@ -1980,10 +1994,10 @@ class TranscriptionRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2006,21 +2020,23 @@ def to_sampling_params( if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) - - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - top_p=top_p, - top_k=top_k, - min_p=min_p, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - presence_penalty=self.presence_penalty, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY, - extra_args=self.vllm_xargs) + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, + output_kind=RequestOutputKind.DELTA + if self.stream else RequestOutputKind.FINAL_ONLY, + extra_args=self.vllm_xargs, + ) @model_validator(mode="before") @classmethod @@ -2196,10 +2212,10 @@ class TranslationRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2209,11 +2225,12 @@ def to_sampling_params( temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY) + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + 
output_kind=RequestOutputKind.DELTA + if self.stream else RequestOutputKind.FINAL_ONLY, + ) @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index e5c522a4992e..2d9741460600 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import re -from functools import lru_cache -from typing import Any, Dict, Final, List, Optional, Tuple, Union +from typing import Any, Final, Optional, Union import jinja2 from fastapi import Request @@ -24,9 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, - encode_tokens) -from vllm.transformers_utils.tokenizers import MistralTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -43,10 +39,12 @@ def __init__( chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format @@ -91,39 +89,45 @@ async def create_tokenize( add_special_tokens=request.add_special_tokens, ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + ( + request_prompts, + engine_prompts, + ) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") input_ids: list[int] = [] for i, engine_prompt in enumerate(engine_prompts): - self._log_inputs(request_id, - request_prompts[i], - params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + self._log_inputs( + request_id, + request_prompts[i], + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) # Silently ignore prompt adapter since it does not affect # tokenization (Unlike in Embeddings API where an error is raised) - if isinstance(engine_prompt, - dict) and "prompt_token_ids" in engine_prompt: + if (isinstance(engine_prompt, dict) + and "prompt_token_ids" in engine_prompt): input_ids.extend(engine_prompt["prompt_token_ids"]) token_strs = None if request.return_token_strs: token_strs = tokenizer.convert_ids_to_tokens(input_ids) - return TokenizeResponse(tokens=input_ids, - token_strs=token_strs, - count=len(input_ids), - max_model_len=self.max_model_len) + return TokenizeResponse( + tokens=input_ids, + token_strs=token_strs, + count=len(input_ids), + max_model_len=self.max_model_len, + ) async def create_detokenize( self, @@ -143,11 +147,13 @@ async def create_detokenize( tokenizer = await self.engine_client.get_tokenizer(lora_request) - self._log_inputs(request_id, - 
request.tokens, - params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + self._log_inputs( + request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) # Silently ignore prompt adapter since it does not affect tokenization # (Unlike in Embeddings API where an error is raised) @@ -162,7 +168,7 @@ async def create_detokenize( return DetokenizeResponse(prompt=input_text) async def get_tokenizer_info( - self) -> Union[TokenizerInfoResponse, ErrorResponse]: + self, ) -> Union[TokenizerInfoResponse, ErrorResponse]: """Get comprehensive tokenizer information.""" try: tokenizer = await self.engine_client.get_tokenizer() @@ -176,38 +182,46 @@ async def get_tokenizer_info( class TokenizerInfo: - def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig, - chat_template: Optional[str]): + def __init__( + self, + tokenizer: AnyTokenizer, + model_config: ModelConfig, + chat_template: Optional[str], + ): self.tokenizer = tokenizer self.model_config = model_config self.chat_template = chat_template - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return the tokenizer configuration.""" return self._get_tokenizer_config() - # Use the tokenizer's init_kwargs as the base (this contains the original config) - def _get_tokenizer_config(self) -> Dict[str, Any]: + # Use the tokenizer's init_kwargs as the base + # (this contains the original config) + def _get_tokenizer_config(self) -> dict[str, Any]: """Get tokenizer configuration directly from the tokenizer object.""" - config = dict(self.tokenizer.init_kwargs) if hasattr(self.tokenizer, 'init_kwargs') and self.tokenizer.init_kwargs else {} - + config = (dict(self.tokenizer.init_kwargs) + if hasattr(self.tokenizer, "init_kwargs") + and self.tokenizer.init_kwargs else {}) + # Remove file path fields - config.pop('vocab_file', None) - config.pop('merges_file', None) - + config.pop("vocab_file", None) + config.pop("merges_file", None) + config = self._make_json_serializable(config) - config['tokenizer_class'] = self.tokenizer.__class__.__bases__[0].__name__ + config["tokenizer_class"] = self.tokenizer.__class__.__bases__[ + 0].__name__ if self.chat_template: - config['chat_template'] = self.chat_template + config["chat_template"] = self.chat_template return config def _make_json_serializable(self, obj): """Convert any non-JSON-serializable objects to serializable format.""" - if hasattr(obj, 'content'): + if hasattr(obj, "content"): return obj.content elif isinstance(obj, dict): return {k: self._make_json_serializable(v) for k, v in obj.items()} elif isinstance(obj, list): return [self._make_json_serializable(item) for item in obj] else: - return obj \ No newline at end of file + return obj From 8ff1d1b0c881ec756995c53bf5357af44cc1643e Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 9 Jul 2025 16:08:34 +0100 Subject: [PATCH 04/10] :art: formatting changes --- tests/entrypoints/openai/test_tokenization.py | 79 ++++---- vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/cli_args.py | 1 - vllm/entrypoints/openai/protocol.py | 175 ++++++++---------- .../openai/serving_tokenization.py | 65 +++---- 5 files changed, 145 insertions(+), 176 deletions(-) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 0980bf5bcab8..badbbd50a2ff 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ 
-41,8 +41,8 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 @pytest.fixture(scope="module") def tokenizer_name(model_name: str, zephyr_lora_added_tokens_files: str): # noqa: F811 - return (zephyr_lora_added_tokens_files if - (model_name == "zephyr-lora2") else model_name) + return zephyr_lora_added_tokens_files if ( + model_name == "zephyr-lora2") else model_name @pytest_asyncio.fixture @@ -69,14 +69,12 @@ async def test_tokenize_completions( prompt = "vllm1 This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post( - server.url_for("tokenize"), - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt, - }, - ) + response = requests.post(server.url_for("tokenize"), + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt + }) response.raise_for_status() result = response.json() @@ -102,20 +100,16 @@ async def test_tokenize_chat( for add_generation in [False, True]: for add_special in [False, True]: - conversation = [ - { - "role": "user", - "content": "Hi there!" - }, - { - "role": "assistant", - "content": "Nice to meet you!" - }, - { - "role": "user", - "content": "Can I ask a question? vllm1" - }, - ] + conversation = [{ + "role": "user", + "content": "Hi there!" + }, { + "role": "assistant", + "content": "Nice to meet you!" + }, { + "role": "user", + "content": "Can I ask a question? vllm1" + }] for continue_final in [False, True]: if add_generation and continue_final: continue @@ -129,21 +123,20 @@ async def test_tokenize_chat( add_generation_prompt=add_generation, continue_final_message=continue_final, conversation=conversation, - tokenize=False, - ) + tokenize=False) tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post( - server.url_for("tokenize"), - json={ - "add_generation_prompt": add_generation, - "continue_final_message": continue_final, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name, - }, - ) + response = requests.post(server.url_for("tokenize"), + json={ + "add_generation_prompt": + add_generation, + "continue_final_message": + continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name + }) response.raise_for_status() result = response.json() @@ -282,13 +275,11 @@ async def test_detokenize( prompt = "This is a test prompt. 
vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) - response = requests.post( - server.url_for("detokenize"), - json={ - "model": model_name, - "tokens": tokens - }, - ) + response = requests.post(server.url_for("detokenize"), + json={ + "model": model_name, + "tokens": tokens + }) response.raise_for_status() assert response.json() == {"prompt": prompt} @@ -394,4 +385,4 @@ async def test_get_tokenizer_info_chat_template(server: RemoteOpenAIServer): if chat_template: assert isinstance(chat_template, str), ("Chat template should be a string") - assert chat_template.strip(), "Chat template should not be empty" + assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cb6675cdcbb6..2d43d07eeb43 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1546,6 +1546,7 @@ async def run_server_worker(listen_address, async with build_async_engine_client(args, client_config) as engine_client: maybe_register_tokenizer_info_endpoint(args) app = build_app(args) + vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index bf4d754aec15..63e5df3f715c 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -301,7 +301,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Enable the /tokenizer_info endpoint. May expose chat " "templates and other tokenizer configuration.") - return parser diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9b7db14534c8..4bbd19c65b22 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -129,7 +129,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): description: Optional[str] = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema") + json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') strict: Optional[bool] = None @@ -200,8 +200,8 @@ def get_logits_processors(processors: Optional[LogitsProcessors], if processors and pattern: logits_processors = [] for processor in processors: - qualname = (processor - if isinstance(processor, str) else processor.qualname) + qualname = processor if isinstance(processor, + str) else processor.qualname if not re.match(pattern, qualname): raise ValueError( f"Logits processor '{qualname}' is not allowed by this " @@ -251,7 +251,7 @@ class ResponsesRequest(OpenAIBaseModel): prompt: Optional[ResponsePrompt] = None reasoning: Optional[Reasoning] = None service_tier: Literal["auto", "default", "flex", "scale", - "priority"] = ("auto") + "priority"] = "auto" store: Optional[bool] = True stream: Optional[bool] = False temperature: Optional[float] = None @@ -356,8 +356,7 @@ class ChatCompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = Field( default=None, deprecated= - "max_tokens is deprecated in favor of the max_completion_tokens field", - ) + 'max_tokens is deprecated in favor of the max_completion_tokens field') max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 @@ -419,7 +418,7 @@ class 
ChatCompletionRequest(OpenAIBaseModel): ("If this is set, the chat will be formatted so that the final " "message in the chat is open-ended, without any EOS tokens. The " "model will continue this message rather than starting a new one. " - 'This allows you to "prefill" part of the model\'s response for it. ' + "This allows you to \"prefill\" part of the model's response for it. " "Cannot be used at the same time as `add_generation_prompt`."), ) add_special_tokens: bool = Field( @@ -438,7 +437,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "the model if it is performing RAG (retrieval-augmented generation)." " If the template does not support RAG, this argument will have no " "effect. We recommend that each document should be a dict containing " - '"title" and "text" keys.'), + "\"title\" and \"text\" keys."), ) chat_template: Optional[str] = Field( default=None, @@ -519,15 +518,13 @@ class ChatCompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}."), - ) + "{'param': 'value'}}.")) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified."), - ) + "that are not JSON-encodable can be identified.")) cache_salt: Optional[str] = Field( default=None, description=( @@ -536,12 +533,10 @@ class ChatCompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0."), - ) + "to 256 bit). 
Not supported by vLLM engine V0.")) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) + description="KVTransfer parameters used for disaggregated serving.") vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, @@ -563,6 +558,7 @@ class ChatCompletionRequest(OpenAIBaseModel): def to_beam_search_params( self, max_tokens: int, default_sampling_params: dict) -> BeamSearchParams: + n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: temperature = default_sampling_params.get( @@ -583,6 +579,7 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: dict, ) -> SamplingParams: + # Default parameters if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( @@ -660,17 +657,17 @@ def to_sampling_params( logits_processor_pattern), include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - bad_words=self.bad_words, + bad_words= self.bad_words, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) def _get_guided_json_from_tool( - self, ) -> Optional[Union[str, dict, BaseModel]]: + self) -> Optional[Union[str, dict, BaseModel]]: # user has chosen to not use any tool if self.tool_choice == "none" or self.tools is None: return None @@ -695,7 +692,7 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: "properties": { "name": { "type": "string", - "enum": [tool.function.name], + "enum": [tool.function.name] }, # parameters are always generated as '{}' in the final # output if they are missing from the request @@ -705,9 +702,9 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: if tool.function.parameters else { "type": "object", "properties": {} - }, + } }, - "required": ["name", "parameters"], + "required": ["name", "parameters"] } json_schema = { @@ -715,8 +712,8 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: "minItems": 1, "items": { "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools], - }, + "anyOf": [get_tool_schema(tool) for tool in self.tools] + } } return json_schema @@ -762,7 +759,7 @@ def check_guided_decoding_count(cls, data): guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None, + "guided_choice" in data and data["guided_choice"] is not None ]) # you can only use one kind of guided decoding if guide_count > 1: @@ -782,6 +779,7 @@ def check_guided_decoding_count(cls, data): @model_validator(mode="before") @classmethod def check_tool_usage(cls, data): + # if "tool_choice" is not specified but tools are provided, # default to "auto" tool_choice if "tool_choice" not in data and data.get("tools"): @@ -793,6 +791,7 @@ def check_tool_usage(cls, data): # if "tool_choice" is specified -- validation if "tool_choice" in data: + # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: raise ValueError( @@ -801,18 +800,18 @@ def check_tool_usage(cls, data): # make sure that tool choice is 
either a named tool # OR that it's set to "auto" or "required" if data["tool_choice"] not in [ - "auto", - "required", + "auto", "required" ] and not isinstance(data["tool_choice"], dict): raise NotImplementedError( - f"Invalid value for `tool_choice`: {data['tool_choice']}! " - 'Only named tools, "none", "auto" or "required" ' - "are supported.") + f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ + 'Only named tools, "none", "auto" or "required" '\ + 'are supported.' + ) # ensure that if "tool_choice" is specified as an object, # it matches a valid tool - correct_usage_message = ('Correct usage: `{"type": "function",' - ' "function": {"name": "my_function"}}`') + correct_usage_message = 'Correct usage: `{"type": "function",' \ + ' "function": {"name": "my_function"}}`' if isinstance(data["tool_choice"], dict): valid_tool = False function = data["tool_choice"].get("function") @@ -824,8 +823,8 @@ def check_tool_usage(cls, data): raise ValueError(f"Expected field `name` in `function` in " f"`tool_choice`! {correct_usage_message}") function_name = function["name"] - if (not isinstance(function_name, str) - or len(function_name) == 0): + if not isinstance(function_name, + str) or len(function_name) == 0: raise ValueError( f"Invalid `name` in `function`: `{function_name}`" f" in `tool_choice`! {correct_usage_message}") @@ -856,8 +855,8 @@ def check_cache_salt_support(cls, data): raise ValueError( "Parameter 'cache_salt' is not supported with " "this instance of vLLM, which uses engine V0.") - if (not isinstance(data["cache_salt"], str) - or not data["cache_salt"]): + if not isinstance(data["cache_salt"], + str) or not data["cache_salt"]: raise ValueError("Parameter 'cache_salt' must be a " "non-empty string if provided.") return data @@ -967,21 +966,18 @@ class CompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. 
For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}."), - ) + "{'param': 'value'}}.")) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified."), - ) + "that are not JSON-encodable can be identified.")) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) + description="KVTransfer parameters used for disaggregated serving.") vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, @@ -1005,6 +1001,7 @@ def to_beam_search_params( max_tokens: int, default_sampling_params: Optional[dict] = None, ) -> BeamSearchParams: + if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 @@ -1027,6 +1024,7 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: Optional[dict] = None, ) -> SamplingParams: + if default_sampling_params is None: default_sampling_params = {} @@ -1098,13 +1096,13 @@ def to_sampling_params( logits_processors=get_logits_processors(self.logits_processors, logits_processor_pattern), truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, guided_decoding=guided_decoding, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, - ) + ) @model_validator(mode="before") @classmethod @@ -1112,7 +1110,7 @@ def check_guided_decoding_count(cls, data): guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None, + "guided_choice" in data and data["guided_choice"] is not None ]) if guide_count > 1: raise ValueError( @@ -1283,10 +1281,8 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams( - use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data, - ) + return PoolingParams(use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data) class RerankRequest(OpenAIBaseModel): @@ -1312,10 +1308,8 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams( - use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data, - ) + return PoolingParams(use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data) class RerankDocument(BaseModel): @@ -1706,7 +1700,7 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: BatchRequestInputBody - @field_validator("body", mode="plain") + @field_validator('body', mode='plain') @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models @@ -1730,12 +1724,8 @@ class BatchResponseData(OpenAIBaseModel): request_id: str # The body of the response. 
- body: Optional[Union[ - ChatCompletionResponse, - EmbeddingResponse, - ScoreResponse, - RerankResponse, - ]] = None + body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, + ScoreResponse, RerankResponse]] = None class BatchRequestOutput(OpenAIBaseModel): @@ -1795,7 +1785,7 @@ class TokenizeChatRequest(OpenAIBaseModel): ("If this is set, the chat will be formatted so that the final " "message in the chat is open-ended, without any EOS tokens. The " "model will continue this message rather than starting a new one. " - 'This allows you to "prefill" part of the model\'s response for it. ' + "This allows you to \"prefill\" part of the model's response for it. " "Cannot be used at the same time as `add_generation_prompt`."), ) add_special_tokens: bool = Field( @@ -1994,10 +1984,10 @@ class TranscriptionRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None, - ) -> SamplingParams: + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + max_tokens = default_max_tokens if default_sampling_params is None: @@ -2020,23 +2010,21 @@ def to_sampling_params( if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], - ) - - return SamplingParams.from_optional( - temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - top_p=top_p, - top_k=top_k, - min_p=min_p, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - presence_penalty=self.presence_penalty, - output_kind=RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY, - extra_args=self.vllm_xargs, - ) + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY, + extra_args=self.vllm_xargs) @model_validator(mode="before") @classmethod @@ -2212,10 +2200,10 @@ class TranslationRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None, - ) -> SamplingParams: + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + max_tokens = default_max_tokens if default_sampling_params is None: @@ -2225,12 +2213,11 @@ def to_sampling_params( temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) - return SamplingParams.from_optional( - temperature=temperature, - max_tokens=max_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY, - ) + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY) @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 2d9741460600..9468f0a83fa3 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: 
Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Any, Final, Optional, Union import jinja2 @@ -39,12 +40,10 @@ def __init__( chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, ) -> None: - super().__init__( - engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - ) + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format @@ -89,45 +88,39 @@ async def create_tokenize( add_special_tokens=request.add_special_tokens, ) else: - ( - request_prompts, - engine_prompts, - ) = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") input_ids: list[int] = [] for i, engine_prompt in enumerate(engine_prompts): - self._log_inputs( - request_id, - request_prompts[i], - params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) + self._log_inputs(request_id, + request_prompts[i], + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) # Silently ignore prompt adapter since it does not affect # tokenization (Unlike in Embeddings API where an error is raised) - if (isinstance(engine_prompt, dict) - and "prompt_token_ids" in engine_prompt): + if isinstance(engine_prompt, + dict) and "prompt_token_ids" in engine_prompt: input_ids.extend(engine_prompt["prompt_token_ids"]) token_strs = None if request.return_token_strs: token_strs = tokenizer.convert_ids_to_tokens(input_ids) - return TokenizeResponse( - tokens=input_ids, - token_strs=token_strs, - count=len(input_ids), - max_model_len=self.max_model_len, - ) + return TokenizeResponse(tokens=input_ids, + token_strs=token_strs, + count=len(input_ids), + max_model_len=self.max_model_len) async def create_detokenize( self, @@ -147,13 +140,11 @@ async def create_detokenize( tokenizer = await self.engine_client.get_tokenizer(lora_request) - self._log_inputs( - request_id, - request.tokens, - params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) + self._log_inputs(request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) # Silently ignore prompt adapter since it does not affect tokenization # (Unlike in Embeddings API where an error is raised) From 167fd626be7d7478f2ae9f1adec36e403b639993 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 9 Jul 2025 16:40:49 +0100 Subject: [PATCH 05/10] :recycle: `TokenizerInfo` is now a dataclass to reduce boiler plate --- .../openai/serving_tokenization.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 9468f0a83fa3..8e994a792b58 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: 
Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from dataclasses import dataclass from typing import Any, Final, Optional, Union import jinja2 @@ -163,32 +163,22 @@ async def get_tokenizer_info( """Get comprehensive tokenizer information.""" try: tokenizer = await self.engine_client.get_tokenizer() - info = TokenizerInfo(tokenizer, self.model_config, - self.chat_template).to_dict() + info = TokenizerInfo(tokenizer, self.chat_template).to_dict() return TokenizerInfoResponse(**info) except Exception as e: return self.create_error_response( f"Failed to get tokenizer info: {str(e)}") +@dataclass class TokenizerInfo: - - def __init__( - self, - tokenizer: AnyTokenizer, - model_config: ModelConfig, - chat_template: Optional[str], - ): - self.tokenizer = tokenizer - self.model_config = model_config - self.chat_template = chat_template + tokenizer: AnyTokenizer + chat_template: Optional[str] def to_dict(self) -> dict[str, Any]: """Return the tokenizer configuration.""" return self._get_tokenizer_config() - # Use the tokenizer's init_kwargs as the base - # (this contains the original config) def _get_tokenizer_config(self) -> dict[str, Any]: """Get tokenizer configuration directly from the tokenizer object.""" config = (dict(self.tokenizer.init_kwargs) From 3b2ea85a4d9e83c8af866d67407e539fe39ad3ed Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 9 Jul 2025 16:56:11 +0100 Subject: [PATCH 06/10] :construction: simplifying tokenizer_class extraction --- vllm/entrypoints/openai/serving_tokenization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 8e994a792b58..cdaf57cf2135 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -190,8 +190,7 @@ def _get_tokenizer_config(self) -> dict[str, Any]: config.pop("merges_file", None) config = self._make_json_serializable(config) - config["tokenizer_class"] = self.tokenizer.__class__.__bases__[ - 0].__name__ + config["tokenizer_class"] = type(self.tokenizer).__name__ if self.chat_template: config["chat_template"] = self.chat_template return config From 0e0c04e2d42999f65700f15ddef8bc319b46e970 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Thu, 10 Jul 2025 09:24:02 +0100 Subject: [PATCH 07/10] :construction: move ConfigDict to top of TokenizerInfoResponse as class-level attribute --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4bbd19c65b22..c53f2f996950 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1855,8 +1855,8 @@ class TokenizerInfoResponse(OpenAIBaseModel): equivalent to tokenizer_config.json """ - tokenizer_class: str model_config = ConfigDict(extra="allow") + tokenizer_class: str class LoadLoRAAdapterRequest(BaseModel): From 8eb54614ac096ca0fbf44c85a41b1612f5513028 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Thu, 10 Jul 2025 11:06:03 +0100 Subject: [PATCH 08/10] :recycle: make config more pythonic --- vllm/entrypoints/openai/serving_tokenization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index cdaf57cf2135..8181b36ed0ba 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ 
b/vllm/entrypoints/openai/serving_tokenization.py @@ -181,9 +181,7 @@ def to_dict(self) -> dict[str, Any]: def _get_tokenizer_config(self) -> dict[str, Any]: """Get tokenizer configuration directly from the tokenizer object.""" - config = (dict(self.tokenizer.init_kwargs) - if hasattr(self.tokenizer, "init_kwargs") - and self.tokenizer.init_kwargs else {}) + config = dict(getattr(self.tokenizer, "init_kwargs", None) or {}) # Remove file path fields config.pop("vocab_file", None) From a42e7e9bbf94f2e98c194e8a2fbeec1adff43c4e Mon Sep 17 00:00:00 2001 From: m-misiura Date: Fri, 11 Jul 2025 10:03:21 +0100 Subject: [PATCH 09/10] :construction: updated test_tokenization tests to reflect that `tokenizer_info` endpoint is optional; also reflected name change of the endpoint in the tests --- tests/entrypoints/openai/test_tokenization.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index badbbd50a2ff..0dbbdfbfd24a 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", + "--enable-tokenizer-info-endpoint", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -291,13 +292,13 @@ async def test_detokenize( [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], indirect=["tokenizer_name"], ) -async def test_get_tokenizer_info_basic( +async def test_tokenizer_info_basic( server: RemoteOpenAIServer, model_name: str, tokenizer_name: str, ): """Test basic tokenizer info endpoint functionality.""" - response = requests.get(server.url_for("get_tokenizer_info")) + response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() result = response.json() assert "tokenizer_class" in result @@ -306,9 +307,9 @@ async def test_get_tokenizer_info_basic( @pytest.mark.asyncio -async def test_get_tokenizer_info_schema(server: RemoteOpenAIServer): +async def test_tokenizer_info_schema(server: RemoteOpenAIServer): """Test that the response matches expected schema types.""" - response = requests.get(server.url_for("get_tokenizer_info")) + response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() result = response.json() field_types = { @@ -334,10 +335,10 @@ async def test_get_tokenizer_info_schema(server: RemoteOpenAIServer): @pytest.mark.asyncio -async def test_get_tokenizer_info_added_tokens_structure( +async def test_tokenizer_info_added_tokens_structure( server: RemoteOpenAIServer, ): """Test added_tokens_decoder structure if present.""" - response = requests.get(server.url_for("get_tokenizer_info")) + response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() result = response.json() added_tokens = result.get("added_tokens_decoder") @@ -353,10 +354,10 @@ async def test_get_tokenizer_info_added_tokens_structure( @pytest.mark.asyncio -async def test_get_tokenizer_info_consistency_with_tokenize( +async def test_tokenizer_info_consistency_with_tokenize( server: RemoteOpenAIServer, ): """Test that tokenizer info is consistent with tokenization endpoint.""" - info_response = requests.get(server.url_for("get_tokenizer_info")) + info_response = requests.get(server.url_for("tokenizer_info")) info_response.raise_for_status() info = info_response.json() tokenize_response = requests.post( @@ 
-376,9 +377,9 @@ async def test_get_tokenizer_info_consistency_with_tokenize( @pytest.mark.asyncio -async def test_get_tokenizer_info_chat_template(server: RemoteOpenAIServer): +async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): """Test chat template is properly included.""" - response = requests.get(server.url_for("get_tokenizer_info")) + response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() result = response.json() chat_template = result.get("chat_template") From a69b705483bba9a3065d51ac78ad797b6b295af4 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 16 Jul 2025 11:59:29 +0100 Subject: [PATCH 10/10] :construction: added the cli arg for the tokenizer info endpoint --- vllm/entrypoints/openai/cli_args.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index bccce73b79f8..6456d009b957 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -182,6 +182,9 @@ class FrontendArgs: """If set to True, enable tracking server_load_metrics in the app state.""" enable_force_include_usage: bool = False """If set to True, including usage on every request.""" + enable_tokenizer_info_endpoint: bool = False + """Enable the /get_tokenizer_info endpoint. May expose chat + templates and other tokenizer configuration.""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
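The series ends by making the endpoint strictly opt-in: a server started without --enable-tokenizer-info-endpoint never registers the route, which is why the updated tests add the flag to the server fixture. A minimal client sketch of the resulting behaviour follows; the route name ("tokenizer_info"), host, and port are assumptions carried over from the test helpers above rather than guarantees of the final interface.

    # Sketch only: assumes a server launched locally with
    #   vllm serve <model> --enable-tokenizer-info-endpoint
    # and that the route resolves to /tokenizer_info as exercised by the tests.
    import requests

    response = requests.get("http://localhost:8000/tokenizer_info", timeout=10)
    response.raise_for_status()
    info = response.json()

    # tokenizer_class is the only field TokenizerInfoResponse requires;
    # everything else is passed through from the tokenizer's init_kwargs
    # (plus chat_template, if one is configured) and may be absent.
    print(info["tokenizer_class"])
    print(info.get("model_max_length"))
    print("chat template present:", bool(info.get("chat_template")))

Because TokenizerInfoResponse is declared with extra="allow", clients should feature-detect optional fields such as chat_template and added_tokens_decoder rather than assuming a fixed schema.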