import time
from collections.abc import AsyncGenerator
from functools import cached_property
-from math import ceil
from typing import Callable, Literal, Optional, TypeVar, Union, cast

import numpy as np

from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsTranscription
from vllm.outputs import RequestOutput
-from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import PlaceholderModule

try:

# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
-MAX_AUDIO_CLIP_SECONDS = 30
-OVERLAP_CHUNK_SECOND = 1
-MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio


class OpenAISpeechToText(OpenAIServing):
@@ -71,63 +66,56 @@ def __init__(
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
-        processor = cached_get_processor(model_config.model)
-        self.max_audio_clip_s = processor.feature_extractor.chunk_length \
-            if hasattr(processor.feature_extractor, 'chunk_length') \
-            else MAX_AUDIO_CLIP_SECONDS
-        self.model_sr = processor.feature_extractor.sampling_rate
-        self.hop_length = processor.feature_extractor.hop_length
        self.task_type = task_type

+        self.asr_config = self.model_cls.get_speech_to_text_config(
+            model_config, task_type)
+
        if self.default_sampling_params:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                self.default_sampling_params)

    @cached_property
-    def model_cls(self):
-        return get_model_cls(self.model_config)
+    def model_cls(self) -> type[SupportsTranscription]:
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsTranscription], model_cls)
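For orientation, the object returned by `get_speech_to_text_config` has to expose the fields this file now reads (`sample_rate`, `max_audio_clip_s`, `overlap_chunk_second`, `min_energy_split_window_size`, `allow_audio_chunking`). Below is a minimal sketch of such a config, with defaults mirroring the constants deleted above; the class name and defaults are illustrative, and the actual config type in vLLM may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpeechToTextConfigSketch:
    """Hypothetical stand-in for the object returned by
    get_speech_to_text_config(); field names are taken from the attribute
    accesses in this diff, defaults mirror the deleted module constants."""

    sample_rate: float = 16_000       # target SR the audio is resampled to
    max_audio_clip_s: int = 30        # longest clip handled in one pass
    overlap_chunk_second: int = 1     # overlap between adjacent chunks
    # Window (in samples) used to search for a low-energy split point;
    # None disables chunking entirely.
    min_energy_split_window_size: Optional[int] = 1600

    @property
    def allow_audio_chunking(self) -> bool:
        return self.min_energy_split_window_size is not None
```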

    async def _preprocess_speech_to_text(
        self,
        request: SpeechToTextRequest,
        audio_data: bytes,
    ) -> tuple[list[PromptType], float]:
-        model_cls = cast(SupportsTranscription, self.model_cls)
-
        # Validate request
        # TODO language should be optional and can be guessed.
        # For now we default to en. See
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
        lang = request.language or "en"
-        model_cls.validate_language(lang)
+        self.model_cls.validate_language(lang)

        if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
            raise ValueError("Maximum file size exceeded.")

        with io.BytesIO(audio_data) as bytes_:
            # NOTE resample to model SR here for efficiency. This is also a
            # pre-requisite for chunking, as it assumes Whisper SR.
-            y, sr = librosa.load(bytes_, sr=self.model_sr)
+            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)

        duration = librosa.get_duration(y=y, sr=sr)
-        chunks = [y
-                  ] if duration < self.max_audio_clip_s else self._split_audio(
-                      y, int(sr))
+        do_split_audio = (self.asr_config.allow_audio_chunking
+                          and duration > self.asr_config.max_audio_clip_s)
+        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
        prompts = []
        for chunk in chunks:
-            prompt = {
-                "encoder_prompt": {
-                    "prompt": "",
-                    "multi_modal_data": {
-                        "audio": (chunk, sr),
-                    },
-                },
-                "decoder_prompt":
-                model_cls.get_decoder_prompt(lang, self.task_type,
-                                             request.prompt)
-            }
-            prompts.append(cast(PromptType, prompt))
+            # The model has control over the construction, as long as it
+            # returns a valid PromptType.
+            prompt = self.model_cls.get_generation_prompt(
+                audio=chunk,
+                stt_config=self.asr_config,
+                language=lang,
+                task_type=self.task_type,
+                request_prompt=request.prompt)
+            prompts.append(prompt)
        return prompts, duration
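The prompt construction deleted above now lives behind `SupportsTranscription.get_generation_prompt`. As a rough sketch of that contract, an encoder-decoder ASR model could rebuild essentially the same dict; the helper name and the exact control-token layout below are illustrative assumptions, not vLLM's actual Whisper implementation.

```python
from typing import Any, Optional

import numpy as np


def get_generation_prompt_sketch(audio: np.ndarray, stt_config: Any,
                                 language: str, task_type: str,
                                 request_prompt: Optional[str]) -> dict:
    # Same shape as the dict the serving layer used to build: raw audio is
    # fed to the encoder, control tokens plus the optional user prompt go
    # to the decoder.
    decoder_prompt = (f"<|startoftranscript|><|{language}|><|{task_type}|>"
                      f"<|notimestamps|>{request_prompt or ''}")
    return {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": (audio, stt_config.sample_rate),
            },
        },
        "decoder_prompt": decoder_prompt,
    }
```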

    async def _create_speech_to_text(
@@ -196,7 +184,8 @@ async def _create_speech_to_text(
        self._log_inputs(
            request_id,
-            prompts[0]['decoder_prompt'],  # type: ignore
+            # It will not display special tokens like <|startoftranscript|>
+            request.prompt,
            params=sampling_params,
            lora_request=None,
            prompt_adapter_request=None)
@@ -261,17 +250,11 @@ async def _speech_to_text_stream_generator(
            async for res in result_generator:
                # On first result.
                if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token
-                    # is not specified.
-                    num_prompt_tokens = max(
-                        len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder
-                    # prompts directly at least not to Whisper.
-                    # One indicator of the encoder amount of processing
-                    # is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(
-                        audio_duration_s * self.model_sr / self.hop_length)
+                    num_prompt_tokens = len(res.prompt_token_ids)
+                    if audio_tokens := self.model_cls.get_num_audio_tokens(
+                            audio_duration_s, self.asr_config,
+                            self.model_config):
+                        num_prompt_tokens += audio_tokens
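The `get_num_audio_tokens` hook replaces the hard-coded Whisper arithmetic deleted above; for a log-mel frontend the estimate is still the spectrogram length. A minimal sketch follows, assuming a 16 kHz sample rate and a hop length of 160 (both assumptions here; real models report these through their feature extractor).

```python
# Mirrors the ceil(duration * sample_rate / hop_length) formula removed in
# this hunk; values are illustrative defaults, not read from any model.
from math import ceil


def estimate_num_audio_tokens(audio_duration_s: float,
                              sample_rate: float = 16_000,
                              hop_length: int = 160) -> int:
    # Number of log-mel frames the encoder will process for this clip.
    return ceil(audio_duration_s * sample_rate / hop_length)
```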

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
@@ -347,8 +330,8 @@ async def _speech_to_text_stream_generator(

    def _split_audio(self, audio_data: np.ndarray,
                     sample_rate: int) -> list[np.ndarray]:
-        chunk_size = sample_rate * self.max_audio_clip_s
-        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
+        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
        chunks = []
        i = 0
        while i < audio_data.shape[-1]:
@@ -384,10 +367,10 @@ def _find_split_point(self, wav: np.ndarray, start_idx: int,
        # Calculate RMS energy in small windows
        min_energy = math.inf
        quietest_idx = 0
-        for i in range(0,
-                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
-                       MIN_ENERGY_WINDOW_SIZE):
-            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+        min_energy_window = self.asr_config.min_energy_split_window_size
+        assert min_energy_window is not None
+        for i in range(0, len(segment) - min_energy_window, min_energy_window):
+            window = segment[i:i + min_energy_window]
            energy = (window**2).mean()**0.5
            if energy < min_energy:
                quietest_idx = i + start_idx
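To tie the chunking knobs together: with the defaults sketched earlier (16 kHz audio, 30 s clips, 1 s overlap, 1600-sample energy window), `_split_audio` advances in 480,000-sample chunks that overlap by up to 16,000 samples, and `_find_split_point` scans that overlap for the quietest ~100 ms window. Below is a standalone sketch of that search under the same RMS criterion as the loop above; it returns an index relative to the segment rather than offset by `start_idx`.

```python
# Standalone illustration of the quietest-point search; not the method in
# this file, just the same RMS-energy scan over fixed-size windows.
import math

import numpy as np


def find_quietest_index(segment: np.ndarray, window: int = 1600) -> int:
    # 1600 samples ~= 100 ms at 16 kHz; scan non-overlapping windows and
    # keep the start index of the one with the lowest RMS energy.
    min_energy = math.inf
    quietest_idx = 0
    for i in range(0, len(segment) - window, window):
        energy = float(np.sqrt(np.mean(segment[i:i + window]**2)))
        if energy < min_energy:
            min_energy = energy
            quietest_idx = i
    return quietest_idx
```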