Skip to content

Commit 58b7f0f

Browse files
NickLucchepatrickvonplaten
authored andcommitted
[Frontend] Abstract prompt and SpeechToTextConfig for transcriptions models (vllm-project#20637)
Signed-off-by: NickLucche <[email protected]>
1 parent a2cb4a7 commit 58b7f0f

File tree

4 files changed

+180
-106
lines changed

4 files changed

+180
-106
lines changed

vllm/config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
49584958
vllm_config.compilation_config.static_forward_context.items()
49594959
if isinstance(layer, layer_type)
49604960
}
4961+
4962+
4963+
@config
4964+
@dataclass
4965+
class SpeechToTextConfig:
4966+
"""Configuration for speech-to-text models."""
4967+
4968+
sample_rate: float = 16_000
4969+
"""Sample rate (Hz) to resample input audio to. Most speech models expect
4970+
16kHz audio input. The input audio will be automatically resampled to this
4971+
rate before processing."""
4972+
4973+
max_audio_clip_s: int = 30
4974+
"""Maximum duration in seconds for a single audio clip without chunking.
4975+
Audio longer than this will be split into smaller chunks if
4976+
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
4977+
4978+
overlap_chunk_second: int = 1
4979+
"""Overlap duration in seconds between consecutive audio chunks when
4980+
splitting long audio. This helps maintain context across chunk boundaries
4981+
and improves transcription quality at split points."""
4982+
4983+
min_energy_split_window_size: Optional[int] = 1600
4984+
"""Window size in samples for finding low-energy (quiet) regions to split
4985+
audio chunks. The algorithm looks for the quietest moment within this
4986+
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
4987+
at 16kHz. If None, no chunking will be done."""
4988+
4989+
@property
4990+
def allow_audio_chunking(self) -> bool:
4991+
return self.min_energy_split_window_size is not None

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 45 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@
66
import time
77
from collections.abc import AsyncGenerator
88
from functools import cached_property
9-
from math import ceil
109
from typing import Callable, Literal, Optional, TypeVar, Union, cast
1110

12-
from mistral_common.audio import Audio
13-
from mistral_common.protocol.transcription.request import TranscriptionRequest
14-
from mistral_common.protocol.instruct.messages import AudioChunk
1511
import numpy as np
1612
from fastapi import Request
1713

@@ -31,8 +27,6 @@
3127
from vllm.model_executor.model_loader import get_model_cls
3228
from vllm.model_executor.models import SupportsTranscription
3329
from vllm.outputs import RequestOutput
34-
from vllm.transformers_utils.processor import cached_get_processor
35-
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
3630
from vllm.utils import PlaceholderModule
3731

3832
try:
@@ -48,9 +42,6 @@
4842
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
4943
# TODO configurable
5044
MAX_AUDIO_CLIP_FILESIZE_MB = 25
51-
MAX_AUDIO_CLIP_SECONDS = 30
52-
OVERLAP_CHUNK_SECOND = 1
53-
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
5445

5546

5647
class OpenAISpeechToText(OpenAIServing):
@@ -75,88 +66,57 @@ def __init__(
7566

7667
self.default_sampling_params = (
7768
self.model_config.get_diff_sampling_param())
78-
79-
self.tokenizer = engine_client.processor.input_preprocessor.tokenizer.tokenizer
80-
if isinstance(self.tokenizer, MistralTokenizer):
81-
audio_encoder = self.tokenizer.instruct.audio_encoder
82-
self.max_audio_clip_s = None
83-
self.model_sr = audio_encoder.audio_config.sampling_rate
84-
self.hop_length = None
85-
else:
86-
processor = cached_get_processor(model_config.model)
87-
self.audio_encoder = None
88-
self.max_audio_clip_s = processor.feature_extractor.chunk_length \
89-
if hasattr(processor.feature_extractor, 'chunk_length') \
90-
else MAX_AUDIO_CLIP_SECONDS
91-
self.model_sr = processor.feature_extractor.sampling_rate
92-
self.hop_length = processor.feature_extractor.hop_length
93-
9469
self.task_type = task_type
9570

71+
self.asr_config = self.model_cls.get_speech_to_text_config(
72+
model_config, task_type)
73+
9674
if self.default_sampling_params:
9775
logger.info(
9876
"Overwriting default completion sampling param with: %s",
9977
self.default_sampling_params)
10078

10179
@cached_property
102-
def model_cls(self):
103-
return get_model_cls(self.model_config)
80+
def model_cls(self) -> type[SupportsTranscription]:
81+
model_cls = get_model_cls(self.model_config)
82+
return cast(type[SupportsTranscription], model_cls)
10483

10584
async def _preprocess_speech_to_text(
10685
self,
10786
request: SpeechToTextRequest,
10887
audio_data: bytes,
10988
) -> tuple[list[PromptType], float]:
110-
model_cls = cast(SupportsTranscription, self.model_cls)
111-
11289
# Validate request
11390
# TODO language should be optional and can be guessed.
11491
# For now we default to en. See
11592
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
11693
lang = request.language or "en"
117-
model_cls.validate_language(lang)
94+
self.model_cls.validate_language(lang)
11895

11996
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
12097
raise ValueError("Maximum file size exceeded.")
12198

122-
prompts = []
123-
if not isinstance(self.tokenizer, MistralTokenizer):
124-
with io.BytesIO(audio_data) as bytes_:
125-
# NOTE resample to model SR here for efficiency. This is also a
126-
# pre-requisite for chunking, as it assumes Whisper SR.
127-
y, sr = librosa.load(bytes_, sr=self.model_sr)
128-
129-
duration = librosa.get_duration(y=y, sr=sr)
130-
131-
chunks = [y
132-
] if duration < self.max_audio_clip_s else self._split_audio(
133-
y, int(sr))
134-
for chunk in chunks:
135-
prompt = {
136-
"encoder_prompt": {
137-
"prompt": "",
138-
"multi_modal_data": {
139-
"audio": (chunk, sr),
140-
},
141-
},
142-
"decoder_prompt":
143-
model_cls.get_decoder_prompt(lang, self.task_type,
144-
request.prompt)
145-
}
146-
prompts.append(cast(PromptType, prompt))
147-
else:
148-
oai_request_dict = request.model_dump()
149-
with io.BytesIO(audio_data) as bytes_:
150-
oai_request_dict["file"] = bytes_
151-
req = TranscriptionRequest.from_openai(oai_request_dict)
152-
153-
duration = req.audio.input_audio.duration
154-
tokenized = self.tokenizer.instruct.encode_transcription(req)
155-
audio = (tokenized.audios[0].audio_array, self.model_sr)
156-
prompts_dict = {"multi_modal_data": {"audio": audio}}
157-
prompts_dict["prompt_token_ids"] = tokenized.tokens
158-
prompts = [cast(PromptType, prompts_dict)]
99+
with io.BytesIO(audio_data) as bytes_:
100+
# NOTE resample to model SR here for efficiency. This is also a
101+
# pre-requisite for chunking, as it assumes Whisper SR.
102+
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
159103

104+
duration = librosa.get_duration(y=y, sr=sr)
105+
do_split_audio = (self.asr_config.allow_audio_chunking
106+
and duration > self.asr_config.max_audio_clip_s)
107+
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
108+
prompts = []
109+
for chunk in chunks:
110+
# The model has control over the construction, as long as it
111+
# returns a valid PromptType.
112+
prompt = self.model_cls.get_generation_prompt(
113+
audio=chunk,
114+
stt_config=self.asr_config,
115+
model_config=self.model_config,
116+
language=lang,
117+
task_type=self.task_type,
118+
request_prompt=request.prompt)
119+
prompts.append(prompt)
160120
return prompts, duration
161121

162122
async def _create_speech_to_text(
@@ -223,13 +183,13 @@ async def _create_speech_to_text(
223183
sampling_params = request.to_sampling_params(
224184
default_max_tokens, self.default_sampling_params)
225185

226-
if "decoder_prompt" in prompts[0]:
227-
self._log_inputs(
228-
request_id,
229-
prompts[0]['decoder_prompt'], # type: ignore
230-
params=sampling_params,
231-
lora_request=None,
232-
prompt_adapter_request=None)
186+
self._log_inputs(
187+
request_id,
188+
# It will not display special tokens like <|startoftranscript|>
189+
request.prompt,
190+
params=sampling_params,
191+
lora_request=None,
192+
prompt_adapter_request=None)
233193

234194
list_result_generator = [
235195
self.engine_client.generate(
@@ -291,17 +251,11 @@ async def _speech_to_text_stream_generator(
291251
async for res in result_generator:
292252
# On first result.
293253
if res.prompt_token_ids is not None:
294-
# Do not account the 4-tokens `<|startoftranscript|>..`
295-
# Could be negative when language token
296-
# is not specified.
297-
num_prompt_tokens = max(
298-
len(res.prompt_token_ids) - 4, 0)
299-
# NOTE(NickLucche) user can't pass encoder
300-
# prompts directly at least not to Whisper.
301-
# One indicator of the encoder amount of processing
302-
# is the log-mel spectogram length.
303-
num_prompt_tokens += ceil(
304-
audio_duration_s * self.model_sr / self.hop_length)
254+
num_prompt_tokens = len(res.prompt_token_ids)
255+
if audio_tokens := self.model_cls.get_num_audio_tokens(
256+
audio_duration_s, self.asr_config,
257+
self.model_config):
258+
num_prompt_tokens += audio_tokens
305259

306260
# We need to do it here, because if there are exceptions in
307261
# the result_generator, it needs to be sent as the FIRST
@@ -377,8 +331,8 @@ async def _speech_to_text_stream_generator(
377331

378332
def _split_audio(self, audio_data: np.ndarray,
379333
sample_rate: int) -> list[np.ndarray]:
380-
chunk_size = sample_rate * self.max_audio_clip_s
381-
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
334+
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
335+
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
382336
chunks = []
383337
i = 0
384338
while i < audio_data.shape[-1]:
@@ -414,10 +368,10 @@ def _find_split_point(self, wav: np.ndarray, start_idx: int,
414368
# Calculate RMS energy in small windows
415369
min_energy = math.inf
416370
quietest_idx = 0
417-
for i in range(0,
418-
len(segment) - MIN_ENERGY_WINDOW_SIZE,
419-
MIN_ENERGY_WINDOW_SIZE):
420-
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
371+
min_energy_window = self.asr_config.min_energy_split_window_size
372+
assert min_energy_window is not None
373+
for i in range(0, len(segment) - min_energy_window, min_energy_window):
374+
window = segment[i:i + min_energy_window]
421375
energy = (window**2).mean()**0.5
422376
if energy < min_energy:
423377
quietest_idx = i + start_idx

vllm/model_executor/models/interfaces.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
66
Union, overload, runtime_checkable)
77

8+
import numpy as np
89
import torch
910
from torch import Tensor
1011
from typing_extensions import Self, TypeIs
1112

13+
from vllm.config import ModelConfig, SpeechToTextConfig
1214
from vllm.inputs import TokensPrompt
15+
from vllm.inputs.data import PromptType
1316
from vllm.logger import init_logger
1417
from vllm.model_executor.layers.quantization.base_config import (
1518
QuantizationConfig)
@@ -692,16 +695,39 @@ class SupportsTranscription(Protocol):
692695
supports_transcription: ClassVar[Literal[True]] = True
693696

694697
@classmethod
695-
def get_decoder_prompt(cls, language: str, task_type: str,
696-
prompt: str) -> str:
697-
"""Get the decoder prompt for the ASR model."""
698+
def get_generation_prompt(cls, audio: np.ndarray,
699+
stt_config: SpeechToTextConfig, language: str,
700+
task_type: str,
701+
request_prompt: str) -> PromptType:
702+
"""Get the prompt for the ASR model.
703+
The model has control over the construction, as long as it
704+
returns a valid PromptType."""
698705
...
699706

700707
@classmethod
701708
def validate_language(cls, language: str) -> bool:
702709
"""Check if the model supports a specific ISO639_1 language."""
703710
...
704711

712+
@classmethod
713+
def get_speech_to_text_config(
714+
cls, model_config: ModelConfig,
715+
task_type: Literal["transcribe",
716+
"translate"]) -> SpeechToTextConfig:
717+
"""Get the speech to text config for the ASR model."""
718+
...
719+
720+
@classmethod
721+
def get_num_audio_tokens(cls, audio_duration_s: float,
722+
stt_config: SpeechToTextConfig,
723+
model_config: ModelConfig) -> Optional[int]:
724+
"""
725+
Map from audio duration to number of audio tokens produced by the ASR
726+
model, without running a forward pass.
727+
This is used for estimating the amount of processing for this audio.
728+
"""
729+
return None
730+
705731

706732
@overload
707733
def supports_transcription(

0 commit comments

Comments
 (0)