Commit 3c7d942

Authored by NickLucche

[Frontend] Abstract prompt and SpeechToTextConfig for transcriptions models (#20637)

Signed-off-by: NickLucche <[email protected]>
1 parent 890323d commit 3c7d942

File tree

4 files changed: +141 -60 lines changed

  vllm/config.py
  vllm/entrypoints/openai/speech_to_text.py
  vllm/model_executor/models/interfaces.py
  vllm/model_executor/models/whisper.py

vllm/config.py

Lines changed: 31 additions & 0 deletions

@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
         vllm_config.compilation_config.static_forward_context.items()
         if isinstance(layer, layer_type)
     }
+
+
+@config
+@dataclass
+class SpeechToTextConfig:
+    """Configuration for speech-to-text models."""
+
+    sample_rate: float = 16_000
+    """Sample rate (Hz) to resample input audio to. Most speech models expect
+    16kHz audio input. The input audio will be automatically resampled to this
+    rate before processing."""
+
+    max_audio_clip_s: int = 30
+    """Maximum duration in seconds for a single audio clip without chunking.
+    Audio longer than this will be split into smaller chunks if
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+
+    overlap_chunk_second: int = 1
+    """Overlap duration in seconds between consecutive audio chunks when
+    splitting long audio. This helps maintain context across chunk boundaries
+    and improves transcription quality at split points."""
+
+    min_energy_split_window_size: Optional[int] = 1600
+    """Window size in samples for finding low-energy (quiet) regions to split
+    audio chunks. The algorithm looks for the quietest moment within this
+    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
+    at 16kHz. If None, no chunking will be done."""
+
+    @property
+    def allow_audio_chunking(self) -> bool:
+        return self.min_energy_split_window_size is not None
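
The new dataclass can be exercised on its own. A minimal sketch, assuming a vLLM checkout that contains this commit (the values shown are just the defaults from the class above):

from vllm.config import SpeechToTextConfig

# Defaults mirror the dataclass above: 16 kHz audio, 30 s clips, 1 s overlap,
# and a 1600-sample (~100 ms) window for locating quiet split points.
cfg = SpeechToTextConfig()
assert cfg.allow_audio_chunking  # min_energy_split_window_size is not None

# Setting the split window to None disables chunking entirely, so audio
# longer than max_audio_clip_s is rejected rather than split.
no_chunk = SpeechToTextConfig(min_energy_split_window_size=None)
assert not no_chunk.allow_audio_chunking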

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 33 additions & 50 deletions

@@ -6,7 +6,6 @@
 import time
 from collections.abc import AsyncGenerator
 from functools import cached_property
-from math import ceil
 from typing import Callable, Literal, Optional, TypeVar, Union, cast
 
 import numpy as np
@@ -28,7 +27,6 @@
 from vllm.model_executor.model_loader import get_model_cls
 from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
-from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
 
 try:
@@ -44,9 +42,6 @@
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-MAX_AUDIO_CLIP_SECONDS = 30
-OVERLAP_CHUNK_SECOND = 1
-MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio
 
 
 class OpenAISpeechToText(OpenAIServing):
@@ -71,63 +66,56 @@ def __init__(
 
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
-        processor = cached_get_processor(model_config.model)
-        self.max_audio_clip_s = processor.feature_extractor.chunk_length \
-            if hasattr(processor.feature_extractor, 'chunk_length') \
-            else MAX_AUDIO_CLIP_SECONDS
-        self.model_sr = processor.feature_extractor.sampling_rate
-        self.hop_length = processor.feature_extractor.hop_length
         self.task_type = task_type
 
+        self.asr_config = self.model_cls.get_speech_to_text_config(
+            model_config, task_type)
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
     @cached_property
-    def model_cls(self):
-        return get_model_cls(self.model_config)
+    def model_cls(self) -> type[SupportsTranscription]:
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsTranscription], model_cls)
 
     async def _preprocess_speech_to_text(
         self,
         request: SpeechToTextRequest,
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
-        model_cls = cast(SupportsTranscription, self.model_cls)
-
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
         lang = request.language or "en"
-        model_cls.validate_language(lang)
+        self.model_cls.validate_language(lang)
 
         if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
             raise ValueError("Maximum file size exceeded.")
 
         with io.BytesIO(audio_data) as bytes_:
             # NOTE resample to model SR here for efficiency. This is also a
             # pre-requisite for chunking, as it assumes Whisper SR.
-            y, sr = librosa.load(bytes_, sr=self.model_sr)
+            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
 
         duration = librosa.get_duration(y=y, sr=sr)
-        chunks = [y
-                  ] if duration < self.max_audio_clip_s else self._split_audio(
-                      y, int(sr))
+        do_split_audio = (self.asr_config.allow_audio_chunking
+                          and duration > self.asr_config.max_audio_clip_s)
+        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
         prompts = []
         for chunk in chunks:
-            prompt = {
-                "encoder_prompt": {
-                    "prompt": "",
-                    "multi_modal_data": {
-                        "audio": (chunk, sr),
-                    },
-                },
-                "decoder_prompt":
-                model_cls.get_decoder_prompt(lang, self.task_type,
-                                             request.prompt)
-            }
-            prompts.append(cast(PromptType, prompt))
+            # The model has control over the construction, as long as it
+            # returns a valid PromptType.
+            prompt = self.model_cls.get_generation_prompt(
+                audio=chunk,
+                stt_config=self.asr_config,
+                language=lang,
+                task_type=self.task_type,
+                request_prompt=request.prompt)
+            prompts.append(prompt)
         return prompts, duration
 
     async def _create_speech_to_text(
@@ -196,7 +184,8 @@ async def _create_speech_to_text(
 
             self._log_inputs(
                 request_id,
-                prompts[0]['decoder_prompt'],  # type: ignore
+                # It will not display special tokens like <|startoftranscript|>
+                request.prompt,
                 params=sampling_params,
                 lora_request=None,
                 prompt_adapter_request=None)
@@ -261,17 +250,11 @@ async def _speech_to_text_stream_generator(
             async for res in result_generator:
                 # On first result.
                 if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token
-                    # is not specified.
-                    num_prompt_tokens = max(
-                        len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder
-                    # prompts directly at least not to Whisper.
-                    # One indicator of the encoder amount of processing
-                    # is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(
-                        audio_duration_s * self.model_sr / self.hop_length)
+                    num_prompt_tokens = len(res.prompt_token_ids)
+                    if audio_tokens := self.model_cls.get_num_audio_tokens(
+                            audio_duration_s, self.asr_config,
+                            self.model_config):
+                        num_prompt_tokens += audio_tokens
 
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
@@ -347,8 +330,8 @@ async def _speech_to_text_stream_generator(
 
     def _split_audio(self, audio_data: np.ndarray,
                      sample_rate: int) -> list[np.ndarray]:
-        chunk_size = sample_rate * self.max_audio_clip_s
-        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
+        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
         i = 0
         while i < audio_data.shape[-1]:
@@ -384,10 +367,10 @@ def _find_split_point(self, wav: np.ndarray, start_idx: int,
         # Calculate RMS energy in small windows
         min_energy = math.inf
         quietest_idx = 0
-        for i in range(0,
-                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
-                       MIN_ENERGY_WINDOW_SIZE):
-            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+        min_energy_window = self.asr_config.min_energy_split_window_size
+        assert min_energy_window is not None
+        for i in range(0, len(segment) - min_energy_window, min_energy_window):
+            window = segment[i:i + min_energy_window]
             energy = (window**2).mean()**0.5
             if energy < min_energy:
                 quietest_idx = i + start_idx
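
For reference, the quiet-point search that _find_split_point now parameterizes can be sketched as a standalone helper (simplified; the real method also bounds the search region and relies on the overlap bookkeeping in _split_audio, which is not shown here):

import numpy as np

def find_quietest_split(segment: np.ndarray, start_idx: int,
                        window: int = 1600) -> int:
    """Pick the lowest-RMS-energy window inside `segment` as the split point.

    Mirrors the scan above: slide a fixed-size window over the candidate
    region and keep the quietest position, so the cut is less likely to land
    in the middle of speech. `window` plays the role of
    min_energy_split_window_size (1600 samples ~ 100 ms at 16 kHz).
    """
    min_energy = float("inf")
    quietest_idx = start_idx
    for i in range(0, len(segment) - window, window):
        chunk = segment[i:i + window]
        energy = float((chunk**2).mean()**0.5)  # RMS energy of the window
        if energy < min_energy:
            min_energy = energy
            quietest_idx = i + start_idx
    return quietest_idx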

vllm/model_executor/models/interfaces.py

Lines changed: 29 additions & 3 deletions

@@ -5,11 +5,14 @@
 from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                     Union, overload, runtime_checkable)
 
+import numpy as np
 import torch
 from torch import Tensor
 from typing_extensions import Self, TypeIs
 
+from vllm.config import ModelConfig, SpeechToTextConfig
 from vllm.inputs import TokensPrompt
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -692,16 +695,39 @@ class SupportsTranscription(Protocol):
     supports_transcription: ClassVar[Literal[True]] = True
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        """Get the decoder prompt for the ASR model."""
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        """Get the prompt for the ASR model.
+        The model has control over the construction, as long as it
+        returns a valid PromptType."""
         ...
 
     @classmethod
     def validate_language(cls, language: str) -> bool:
         """Check if the model supports a specific ISO639_1 language."""
         ...
 
+    @classmethod
+    def get_speech_to_text_config(
+            cls, model_config: ModelConfig,
+            task_type: Literal["transcribe",
+                               "translate"]) -> SpeechToTextConfig:
+        """Get the speech to text config for the ASR model."""
+        ...
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        """
+        Map from audio duration to number of audio tokens produced by the ASR
+        model, without running a forward pass.
+        This is used for estimating the amount of processing for this audio.
+        """
+        return None
+
 
 @overload
 def supports_transcription(
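
To illustrate what a model now has to provide, here is a hypothetical implementation of the extended protocol (MyASRModel, its language list, and the 50-tokens-per-second rate are all made up for this sketch; only the classmethod signatures come from the interface above):

import math
from typing import Literal, Optional

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType


class MyASRModel:  # hypothetical SupportsTranscription implementer
    supports_transcription = True

    @classmethod
    def validate_language(cls, language: str) -> bool:
        return language in ("en", "it")

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        # Any structure works as long as it is a valid PromptType; here a
        # plain text prompt carrying the audio as multi-modal data.
        return {
            "prompt": f"<asr:{language}:{task_type}>{request_prompt}",
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        }

    @classmethod
    def get_speech_to_text_config(
            cls, model_config: ModelConfig,
            task_type: Literal["transcribe", "translate"]
    ) -> SpeechToTextConfig:
        # Static defaults; Whisper instead derives these from its processor.
        return SpeechToTextConfig(sample_rate=16_000, max_audio_clip_s=30)

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float,
                             stt_config: SpeechToTextConfig,
                             model_config: ModelConfig) -> Optional[int]:
        # Assumed fixed rate of 50 audio tokens per second of input.
        return math.ceil(audio_duration_s * 50)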

vllm/model_executor/models/whisper.py

Lines changed: 48 additions & 7 deletions

@@ -3,17 +3,20 @@
 
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, cast
 
+import numpy as np
 import torch
 from torch import nn
 from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
                           WhisperProcessor)
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
+                         VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@
     EncDecMultiModalProcessor,
     PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ def validate_language(cls, language: str) -> bool:
                 f"or {list(ISO639_1_OTHER_LANGS.values())}")
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        return ((f"<|prev|>{prompt}" if prompt else "") +
-                f"<|startoftranscript|><|{language}|>" +
-                f"<|{task_type}|><|notimestamps|>")
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        prompt = {
+            "encoder_prompt": {
+                # Whisper does not support encoder prompt.
+                "prompt": "",
+                "multi_modal_data": {
+                    "audio": (audio, stt_config.sample_rate),
+                },
+            },
+            "decoder_prompt":
+            ((f"<|prev|>{request_prompt}" if request_prompt else "") +
+             f"<|startoftranscript|><|{language}|>" +
+             f"<|{task_type}|><|notimestamps|>")
+        }
+        return cast(PromptType, prompt)
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -798,6 +815,30 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
 
         raise ValueError("Only audio modality is supported")
 
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        processor = cached_get_processor(model_config.model)
+
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+        )
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        processor = cached_get_processor(model_config.model)
+        hop_length = processor.feature_extractor.hop_length
+        assert hop_length is not None
+        # NOTE(NickLucche) user can't pass encoder
+        # prompts directly at least not to Whisper.
+        # One indicator of the encoder amount of processing
+        # is the log-mel spectogram length.
+        return math.ceil(audio_duration_s * stt_config.sample_rate /
+                         hop_length)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
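
To make the estimate above concrete, a quick check with the stock Whisper feature extractor values (16 kHz sampling rate and a hop length of 160 samples, assumed here rather than read from a checkpoint): a 30-second clip maps to 3000 log-mel frames, which is the number added to the prompt token count.

import math

sample_rate = 16_000   # typical Whisper feature extractor sampling rate
hop_length = 160       # typical Whisper feature extractor hop length
audio_duration_s = 30.0

num_audio_tokens = math.ceil(audio_duration_s * sample_rate / hop_length)
print(num_audio_tokens)  # 3000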
