Skip to content

Commit 73647d9

Browse files
sanchit-gandhigemini-code-assist[bot]
authored andcommitted
[Bugfix] Relax lang pin for voxtral (vllm-project#21833)
Signed-off-by: Sanchit Gandhi <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: jingyu <[email protected]>
1 parent 2d6c1b2 commit 73647d9

File tree

4 files changed

+80
-80
lines changed

4 files changed

+80
-80
lines changed

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,7 @@ async def _preprocess_speech_to_text(
8686
audio_data: bytes,
8787
) -> tuple[list[PromptType], float]:
8888
# Validate request
89-
# TODO language should be optional and can be guessed.
90-
# For now we default to en. See
91-
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
92-
lang = request.language or "en"
93-
self.model_cls.validate_language(lang)
89+
language = self.model_cls.validate_language(request.language)
9490

9591
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
9692
raise ValueError("Maximum file size exceeded.")
@@ -112,7 +108,7 @@ async def _preprocess_speech_to_text(
112108
audio=chunk,
113109
stt_config=self.asr_config,
114110
model_config=self.model_config,
115-
language=lang,
111+
language=language,
116112
task_type=self.task_type,
117113
request_prompt=request.prompt)
118114
prompts.append(prompt)

vllm/model_executor/models/interfaces.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
from collections.abc import Iterable, MutableSequence
4+
from collections.abc import Iterable, Mapping, MutableSequence
55
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
66
Union, overload, runtime_checkable)
77

88
import numpy as np
99
import torch
1010
from torch import Tensor
11+
from transformers.models.whisper.tokenization_whisper import LANGUAGES
1112
from typing_extensions import Self, TypeIs
1213

1314
from vllm.config import ModelConfig, SpeechToTextConfig
@@ -685,6 +686,8 @@ def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
685686
@runtime_checkable
686687
class SupportsTranscription(Protocol):
687688
"""The interface required for all models that support transcription."""
689+
# Mapping from ISO639_1 language codes: language names
690+
supported_languages: ClassVar[Mapping[str, str]]
688691

689692
supports_transcription: ClassVar[Literal[True]] = True
690693

@@ -694,21 +697,59 @@ class SupportsTranscription(Protocol):
694697
`True`.
695698
"""
696699

700+
def __init_subclass__(cls, **kwargs):
701+
super().__init_subclass__(**kwargs)
702+
# language codes in supported_languages
703+
# that don't exist in the full language map
704+
invalid = set(cls.supported_languages) - set(LANGUAGES.keys())
705+
if invalid:
706+
raise ValueError(
707+
f"{cls.__name__}.supported_languages contains invalid "
708+
f"language codes: {sorted(invalid)}\n. "
709+
f"Valid choices are: {sorted(LANGUAGES.keys())}")
710+
697711
@classmethod
698712
def get_generation_prompt(cls, audio: np.ndarray,
699713
stt_config: SpeechToTextConfig,
700-
model_config: ModelConfig, language: str,
701-
task_type: str,
714+
model_config: ModelConfig,
715+
language: Optional[str], task_type: str,
702716
request_prompt: str) -> PromptType:
703717
"""Get the prompt for the ASR model.
704718
The model has control over the construction, as long as it
705719
returns a valid PromptType."""
706720
...
707721

708722
@classmethod
709-
def validate_language(cls, language: str) -> bool:
710-
"""Check if the model supports a specific ISO639_1 language."""
711-
...
723+
def get_other_languages(cls) -> Mapping[str, str]:
724+
# other possible language codes from the whisper map
725+
return {
726+
k: v
727+
for k, v in LANGUAGES.items() if k not in cls.supported_languages
728+
}
729+
730+
@classmethod
731+
def validate_language(cls, language: Optional[str]) -> Optional[str]:
732+
"""
733+
Ensure the language specified in the transcription request
734+
is a valid ISO 639-1 language code. If the request language is
735+
valid, but not natively supported by the model, trigger a
736+
warning (but not an exception).
737+
"""
738+
if language is None or language in cls.supported_languages:
739+
return language
740+
elif language in cls.get_other_languages():
741+
logger.warning(
742+
"Language %r is not natively supported by %s; "
743+
"results may be less accurate. Supported languages: %r",
744+
language,
745+
cls.__name__,
746+
list(cls.supported_languages.keys()),
747+
)
748+
return language
749+
else:
750+
raise ValueError(
751+
f"Unsupported language: {language!r}. Must be one of "
752+
f"{list(cls.supported_languages.keys())}.")
712753

713754
@classmethod
714755
def get_speech_to_text_config(

vllm/model_executor/models/voxtral.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2727
from vllm.model_executor.models import SupportsPP
2828
# yapf: disable
29-
from vllm.model_executor.models.whisper import (
30-
WhisperEncoder, WhisperForConditionalGeneration)
29+
from vllm.model_executor.models.whisper import WhisperEncoder
3130
# yapf: enable
3231
from vllm.model_executor.sampling_metadata import SamplingMetadata
3332
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -50,6 +49,18 @@
5049

5150
logger = init_logger(__name__)
5251

52+
ISO639_1_SUPPORTED_LANGS = {
53+
"ar": "Arabic",
54+
"nl": "Dutch",
55+
"en": "English",
56+
"fr": "French",
57+
"de": "German",
58+
"hi": "Hindi",
59+
"it": "Italian",
60+
"pt": "Portuguese",
61+
"es": "Spanish",
62+
}
63+
5364

5465
class VoxtralProcessorAdapter:
5566
"""
@@ -301,6 +312,7 @@ def _get_data_parser(self) -> MultiModalDataParser:
301312
dummy_inputs=VoxtralDummyInputsBuilder)
302313
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
303314
SupportsPP, SupportsTranscription):
315+
supported_languages = ISO639_1_SUPPORTED_LANGS
304316

305317
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
306318
super().__init__()
@@ -441,8 +453,8 @@ def get_speech_to_text_config(cls, model_config: ModelConfig,
441453
# for speech-to-text transcription
442454
def get_generation_prompt(cls, audio: np.ndarray,
443455
model_config: ModelConfig,
444-
stt_config: SpeechToTextConfig, language: str,
445-
task_type: str,
456+
stt_config: SpeechToTextConfig,
457+
language: Optional[str], task_type: str,
446458
request_prompt: str) -> PromptType:
447459
tokenizer = cached_tokenizer_from_config(model_config)
448460
audio = Audio(audio, int(stt_config.sample_rate),
@@ -457,11 +469,6 @@ def get_generation_prompt(cls, audio: np.ndarray,
457469
prompts_dict["prompt_token_ids"] = tokenized.tokens
458470
return cast(PromptType, prompts_dict)
459471

460-
@classmethod
461-
def validate_language(cls, language: str) -> bool:
462-
# same as whisper
463-
return WhisperForConditionalGeneration.validate_language(language)
464-
465472
@classmethod
466473
def get_num_audio_tokens(cls, audio_duration_s: float,
467474
stt_config: SpeechToTextConfig,

vllm/model_executor/models/whisper.py

Lines changed: 15 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -109,51 +109,6 @@
109109
"vi": "Vietnamese",
110110
"cy": "Welsh"
111111
}
112-
ISO639_1_OTHER_LANGS = {
113-
"lo": "Lao",
114-
"jw": "Javanese",
115-
"tk": "Turkmen",
116-
"yi": "Yiddish",
117-
"so": "Somali",
118-
"bn": "Bengali",
119-
"nn": "Norwegian Nynorsk",
120-
"si": "Sinhala",
121-
"yo": "Yoruba",
122-
"sa": "Sanskrit",
123-
"mi": "Māori",
124-
"fo": "Faroese", # codespell:ignore
125-
"mt": "Maltese",
126-
"tg": "Tajik",
127-
"mg": "Malagasy",
128-
"haw": "Hawaiian",
129-
"km": "Khmer",
130-
"br": "Breton",
131-
"ps": "Pashto",
132-
"ln": "Lingala",
133-
"la": "Latin",
134-
"ml": "Malayalam",
135-
"sq": "Albanian",
136-
"su": "Sundanese",
137-
"eu": "Basque",
138-
"ka": "Georgian",
139-
"uz": "Uzbek",
140-
"sn": "Shona",
141-
"ht": "Haitian",
142-
"as": "Assamese",
143-
"mn": "Mongolian",
144-
"te": "Telugu",
145-
"pa": "Panjabi",
146-
"tt": "Tatar",
147-
"gu": "Gujarati",
148-
"oc": "Occitan",
149-
"ha": "Hausa",
150-
"ba": "Bashkir",
151-
"my": "Burmese",
152-
"sd": "Sindhi",
153-
"am": "Amharic",
154-
"lb": "Luxembourgish",
155-
"bo": "Tibetan"
156-
}
157112

158113

159114
class WhisperAudioInputs(TypedDict):
@@ -807,32 +762,33 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
807762

808763
# Whisper only supports audio-conditioned generation.
809764
supports_transcription_only = True
765+
supported_languages = ISO639_1_SUPPORTED_LANGS
810766

811767
@classmethod
812-
def validate_language(cls, language: str) -> bool:
813-
if language in ISO639_1_SUPPORTED_LANGS:
814-
return True
815-
elif language in ISO639_1_OTHER_LANGS:
768+
def validate_language(cls, language: Optional[str]) -> Optional[str]:
769+
if language is None:
770+
# TODO language should be optional and can be guessed.
771+
# For now we default to en. See
772+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
816773
logger.warning(
817-
"The selected language %s has limited accuracy with"
818-
" reported WER>=0.5. Results may be less accurate "
819-
"for this choice.", language)
820-
return True
821-
else:
822-
raise ValueError(f"Unsupported language: {language}."
823-
"Language should be one of:" +
824-
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
825-
f"or {list(ISO639_1_OTHER_LANGS.values())}")
774+
"Defaulting to language='en'. If you wish to transcribe "
775+
"audio in a different language, pass the `language` field "
776+
"in the TranscriptionRequest.")
777+
language = "en"
778+
return super().validate_language(language)
826779

827780
@classmethod
828781
def get_generation_prompt(
829782
cls,
830783
audio: np.ndarray,
831784
model_config: ModelConfig, # not needed here
832785
stt_config: SpeechToTextConfig,
833-
language: str,
786+
language: Optional[str],
834787
task_type: str,
835788
request_prompt: str) -> PromptType:
789+
if language is None:
790+
raise ValueError(
791+
"Language must be specified when creating the Whisper prompt")
836792
prompt = {
837793
"encoder_prompt": {
838794
# Whisper does not support encoder prompt.

0 commit comments

Comments
 (0)