@@ -6,12 +6,8 @@
 import time
 from collections.abc import AsyncGenerator
 from functools import cached_property
-from math import ceil
 from typing import Callable, Literal, Optional, TypeVar, Union, cast

-from mistral_common.audio import Audio
-from mistral_common.protocol.transcription.request import TranscriptionRequest
-from mistral_common.protocol.instruct.messages import AudioChunk
 import numpy as np
 from fastapi import Request

@@ -31,8 +27,6 @@
 from vllm.model_executor.model_loader import get_model_cls
 from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
-from vllm.transformers_utils.processor import cached_get_processor
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import PlaceholderModule

 try:
@@ -48,9 +42,6 @@
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-MAX_AUDIO_CLIP_SECONDS = 30
-OVERLAP_CHUNK_SECOND = 1
-MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio


 class OpenAISpeechToText(OpenAIServing):
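The chunking constants deleted above do not disappear: the rest of this diff reads them from a per-model speech-to-text config returned by get_speech_to_text_config (introduced in the next hunk). For orientation, a minimal sketch of what such a config could look like; the field names are inferred from the attribute accesses later in this diff and the defaults from the deleted constants, so treat it as illustrative rather than the actual vLLM definition.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SpeechToTextConfig:
    """Sketch of per-model ASR settings consumed by OpenAISpeechToText."""
    sample_rate: int = 16_000  # target rate passed to librosa.load
    max_audio_clip_s: Optional[int] = 30  # clips longer than this get chunked
    overlap_chunk_second: int = 1  # seconds shared by adjacent chunks
    min_energy_split_window_size: Optional[int] = 1600  # ~100 ms at 16 kHz

    @property
    def allow_audio_chunking(self) -> bool:
        # Chunking needs an energy window to pick split points, which matches
        # the `assert min_energy_window is not None` in _find_split_point below.
        return self.min_energy_split_window_size is not None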
@@ -75,88 +66,57 @@ def __init__(
 
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
-
-        self.tokenizer = engine_client.processor.input_preprocessor.tokenizer.tokenizer
-        if isinstance(self.tokenizer, MistralTokenizer):
-            audio_encoder = self.tokenizer.instruct.audio_encoder
-            self.max_audio_clip_s = None
-            self.model_sr = audio_encoder.audio_config.sampling_rate
-            self.hop_length = None
-        else:
-            processor = cached_get_processor(model_config.model)
-            self.audio_encoder = None
-            self.max_audio_clip_s = processor.feature_extractor.chunk_length \
-                if hasattr(processor.feature_extractor, 'chunk_length') \
-                else MAX_AUDIO_CLIP_SECONDS
-            self.model_sr = processor.feature_extractor.sampling_rate
-            self.hop_length = processor.feature_extractor.hop_length
-
         self.task_type = task_type
 
+        self.asr_config = self.model_cls.get_speech_to_text_config(
+            model_config, task_type)
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
     @cached_property
-    def model_cls(self):
-        return get_model_cls(self.model_config)
+    def model_cls(self) -> type[SupportsTranscription]:
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsTranscription], model_cls)
 
     async def _preprocess_speech_to_text(
         self,
         request: SpeechToTextRequest,
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
-        model_cls = cast(SupportsTranscription, self.model_cls)
-
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
         lang = request.language or "en"
-        model_cls.validate_language(lang)
+        self.model_cls.validate_language(lang)
 
         if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
             raise ValueError("Maximum file size exceeded.")
 
-        prompts = []
-        if not isinstance(self.tokenizer, MistralTokenizer):
-            with io.BytesIO(audio_data) as bytes_:
-                # NOTE resample to model SR here for efficiency. This is also a
-                # pre-requisite for chunking, as it assumes Whisper SR.
-                y, sr = librosa.load(bytes_, sr=self.model_sr)
-
-            duration = librosa.get_duration(y=y, sr=sr)
-
-            chunks = [y
-                      ] if duration < self.max_audio_clip_s else self._split_audio(
-                          y, int(sr))
-            for chunk in chunks:
-                prompt = {
-                    "encoder_prompt": {
-                        "prompt": "",
-                        "multi_modal_data": {
-                            "audio": (chunk, sr),
-                        },
-                    },
-                    "decoder_prompt":
-                    model_cls.get_decoder_prompt(lang, self.task_type,
-                                                 request.prompt)
-                }
-                prompts.append(cast(PromptType, prompt))
-        else:
-            oai_request_dict = request.model_dump()
-            with io.BytesIO(audio_data) as bytes_:
-                oai_request_dict["file"] = bytes_
-                req = TranscriptionRequest.from_openai(oai_request_dict)
-
-            duration = req.audio.input_audio.duration
-            tokenized = self.tokenizer.instruct.encode_transcription(req)
-            audio = (tokenized.audios[0].audio_array, self.model_sr)
-            prompts_dict = {"multi_modal_data": {"audio": audio}}
-            prompts_dict["prompt_token_ids"] = tokenized.tokens
-            prompts = [cast(PromptType, prompts_dict)]
+        with io.BytesIO(audio_data) as bytes_:
+            # NOTE resample to model SR here for efficiency. This is also a
+            # pre-requisite for chunking, as it assumes Whisper SR.
+            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
 
+        duration = librosa.get_duration(y=y, sr=sr)
+        do_split_audio = (self.asr_config.allow_audio_chunking
+                          and duration > self.asr_config.max_audio_clip_s)
+        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
+        prompts = []
+        for chunk in chunks:
+            # The model has control over the construction, as long as it
+            # returns a valid PromptType.
+            prompt = self.model_cls.get_generation_prompt(
+                audio=chunk,
+                stt_config=self.asr_config,
+                model_config=self.model_config,
+                language=lang,
+                task_type=self.task_type,
+                request_prompt=request.prompt)
+            prompts.append(prompt)
         return prompts, duration
 
     async def _create_speech_to_text(
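Prompt construction is now delegated to the model class via get_generation_prompt. As a rough sketch only, a Whisper-style encoder-decoder model might implement it by mirroring the prompt dict the removed code built inline; _get_decoder_prefix is a hypothetical helper standing in for however the model assembles its language/task prefix.

class WhisperStyleASRModel:  # hypothetical model implementing SupportsTranscription

    @classmethod
    def get_generation_prompt(cls, audio, stt_config, model_config, language,
                              task_type, request_prompt):
        # Raw waveform feeds the encoder; the language/task prefix (plus any
        # user-supplied prompt) becomes the decoder prompt, just like the
        # hard-coded dict removed from the serving layer above.
        return {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate),
                },
            },
            "decoder_prompt": cls._get_decoder_prefix(language, task_type,
                                                      request_prompt),
        }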
@@ -223,13 +183,13 @@ async def _create_speech_to_text(
         sampling_params = request.to_sampling_params(
             default_max_tokens, self.default_sampling_params)
 
-        if "decoder_prompt" in prompts[0]:
-            self._log_inputs(
-                request_id,
-                prompts[0]['decoder_prompt'],  # type: ignore
-                params=sampling_params,
-                lora_request=None,
-                prompt_adapter_request=None)
+        self._log_inputs(
+            request_id,
+            # It will not display special tokens like <|startoftranscript|>
+            request.prompt,
+            params=sampling_params,
+            lora_request=None,
+            prompt_adapter_request=None)
 
         list_result_generator = [
             self.engine_client.generate(
@@ -291,17 +251,11 @@ async def _speech_to_text_stream_generator(
             async for res in result_generator:
                 # On first result.
                 if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token
-                    # is not specified.
-                    num_prompt_tokens = max(
-                        len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder
-                    # prompts directly at least not to Whisper.
-                    # One indicator of the encoder amount of processing
-                    # is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(
-                        audio_duration_s * self.model_sr / self.hop_length)
+                    num_prompt_tokens = len(res.prompt_token_ids)
+                    if audio_tokens := self.model_cls.get_num_audio_tokens(
+                            audio_duration_s, self.asr_config,
+                            self.model_config):
+                        num_prompt_tokens += audio_tokens
 
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
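The deleted spectrogram arithmetic suggests what get_num_audio_tokens might return for a Whisper-style model. A sketch under that assumption; the 160-sample hop length is the usual Whisper feature-extractor value and is not stated by this diff.

from math import ceil
from typing import Optional


class WhisperStyleASRModel:  # hypothetical, continuing the sketch above

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float, stt_config,
                             model_config) -> Optional[int]:
        # Approximate encoder-side work by the log-mel spectrogram length,
        # which is what the deleted inline computation did with hop_length.
        hop_length = 160  # assumed; ~10 ms frames at 16 kHz audio
        return ceil(audio_duration_s * stt_config.sample_rate / hop_length)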
@@ -377,8 +331,8 @@ async def _speech_to_text_stream_generator(
 
     def _split_audio(self, audio_data: np.ndarray,
                      sample_rate: int) -> list[np.ndarray]:
-        chunk_size = sample_rate * self.max_audio_clip_s
-        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
+        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
         i = 0
         while i < audio_data.shape[-1]:
@@ -414,10 +368,10 @@ def _find_split_point(self, wav: np.ndarray, start_idx: int,
         # Calculate RMS energy in small windows
         min_energy = math.inf
         quietest_idx = 0
-        for i in range(0,
-                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
-                       MIN_ENERGY_WINDOW_SIZE):
-            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+        min_energy_window = self.asr_config.min_energy_split_window_size
+        assert min_energy_window is not None
+        for i in range(0, len(segment) - min_energy_window, min_energy_window):
+            window = segment[i:i + min_energy_window]
             energy = (window**2).mean()**0.5
             if energy < min_energy:
                 quietest_idx = i + start_idx
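Taken together with config values like those sketched earlier (16 kHz sampling, a 30 s clip limit, 1 s overlap), a 70 s recording would be cut into roughly three chunks of at most 30 s each, adjacent chunks sharing about a second of audio, with each cut nudged to the quietest ~100 ms window that _find_split_point identifies.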