Commit 16f557d

v1: Add Whisper model support (encoder-decoder)
This brings Whisper support to V1 to close one of the remaining feature gaps with V0. Most of the changes apply to encoder-decoder models generally, though Whisper is the only one explicitly tested and the only encoder-decoder model updated to support V1.

**Whisper Model Implementation:**
- Remove the SupportsV0Only interface constraint to enable V1 compatibility
- Update get_multimodal_embeddings() to return the list format required by V1

**Flash Attention Backend:**
- Add encoder attention metadata fields (encoder_seq_start_loc, max_encoder_seq_len, cross_slot_mapping)
- Implement encoder self-attention support without using the KV cache
- Add cross-attention support for encoder-decoder models with proper KV cache handling

**KV Cache Manager:**
- Introduce CrossAttentionManager for handling the cross-attention KV cache in encoder-decoder models
- Add CrossAttentionSpec for cross-attention cache specification with encoder-based sizing
- Implement allocate_slots_for_cross_attn() for static, encoder-length-based allocation
- Add cross-attention block allocation logic separate from decoder token growth

**Scheduler:**
- Disable prefix caching for encoder-decoder models
- Implement cross-attention block allocation during request scheduling
- Add cross-attention block tracking in state management

**GPU Model Runner:**
- Add encoder input extraction for audio feature processing
- Implement encoder attention metadata building for both self-attention and cross-attention
- Add cross-attention KV cache group handling with proper slot mapping
- Modify input batch creation to accommodate encoder sequence lengths
- Add encoder input processing in the forward pass with proper device/dtype handling
- Update profiling and memory management for encoder-decoder models

The implementation maintains backward compatibility while adding comprehensive encoder-decoder support, with particular focus on Whisper's audio processing pipeline and the cross-attention mechanism between encoder and decoder.

Related to:
- V0 deprecation: #18571
- 2025 Q3 roadmap: #20336

Signed-off-by: Russell Bryant <[email protected]>
1 parent b9a21e9 commit 16f557d
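
As a point of reference for the KV cache manager changes listed above: unlike decoder blocks, which grow as tokens are generated, the cross-attention cache is sized once from the encoder output length. The sketch below illustrates that static sizing with a hypothetical num_cross_attn_blocks() helper and illustrative numbers; it is not the CrossAttentionManager or allocate_slots_for_cross_attn() code from this commit.

import math

def num_cross_attn_blocks(encoder_seq_len: int, block_size: int) -> int:
    # Hypothetical helper: cross-attention KV blocks needed for one request.
    # The allocation is static -- sized by the encoder output length -- and
    # does not grow as the decoder generates tokens.
    return math.ceil(encoder_seq_len / block_size)

# Example: a Whisper-style encoder emits 1500 positions for a 30 s window;
# with 16-slot blocks each request needs ceil(1500 / 16) = 94 blocks,
# allocated once at scheduling time.
assert num_cross_attn_blocks(1500, 16) == 94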

File tree

13 files changed: +651 -87 lines changed


vllm/attention/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
     "AttentionMetadata",
     "AttentionType",
     "AttentionMetadataBuilder",
-    "Attention",
     "AttentionState",
     "get_attn_backend",
 ]

vllm/inputs/preprocess.py

Lines changed: 0 additions & 6 deletions
@@ -869,9 +869,6 @@ def preprocess(
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -903,9 +900,6 @@ async def preprocess_async(
         [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
         """
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(prompt)

vllm/model_executor/models/whisper.py

Lines changed: 4 additions & 5 deletions
@@ -42,7 +42,7 @@
 from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
-                         SupportsTranscription, SupportsV0Only)
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)
 
@@ -790,7 +790,7 @@ def _get_prompt_updates(
                                         info=WhisperProcessingInfo,
                                         dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal, SupportsV0Only):
+                                      SupportsMultiModal):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
@@ -916,10 +916,9 @@ def get_language_model(self) -> torch.nn.Module:
 
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
-        # TODO: This method does not obey the interface for SupportsMultiModal.
-        # Refactor this once encoder/decoder support is implemented in V1.
+        # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return self.model.get_encoder_outputs(audio_input["input_features"])
+        return [self.model.get_encoder_outputs(audio_input["input_features"])]
 
     def get_input_embeddings(
         self,

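For context on the get_multimodal_embeddings() change above: V1 expects one embedding tensor per multimodal item, so Whisper's single encoder output is wrapped in a list. A minimal, self-contained illustration of that convention follows; encode_audio() and the tensor shapes are placeholders, not vLLM's implementation.

import torch

def encode_audio(input_features: torch.Tensor) -> torch.Tensor:
    # Stand-in for model.get_encoder_outputs(): one embedding tensor of shape
    # [num_audio_tokens, hidden_size] for a single audio clip. The 2x frame
    # downsampling and hidden size are illustrative.
    return torch.zeros(input_features.shape[-1] // 2, 1280)

def get_multimodal_embeddings(input_features: torch.Tensor) -> list[torch.Tensor]:
    # V1 consumes one tensor per multimodal item, so even a single clip is
    # returned as a one-element list rather than a bare tensor.
    return [encode_audio(input_features)]

embeddings = get_multimodal_embeddings(torch.zeros(1, 80, 3000))
assert isinstance(embeddings, list) and embeddings[0].shape == (1500, 1280)
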
vllm/v1/attention/backends/flash_attn.py

Lines changed: 146 additions & 33 deletions
@@ -130,6 +130,16 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into sequence. E.g., if the sequence
+    # length is [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
@@ -142,6 +152,14 @@ class LocalAttentionMetadata:
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+

 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -219,7 +237,13 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
@@ -374,6 +398,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
         )
         return attn_metadata
 
@@ -428,18 +456,32 @@ def __init__(
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
             and not flash_attn_supports_fp8():
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")
 
+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: Returns `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise returns `False`.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -476,6 +518,14 @@ def forward(
             # Profiling run.
             return output
 
+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -486,22 +536,40 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
             # not padded. However, we don't need to do key[:num_actual_tokens]
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
             reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                 self.kv_cache_dtype,
                 layer._k_scale,
                 layer._v_scale,
@@ -539,7 +607,7 @@ def forward(
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
@@ -551,7 +619,7 @@ def forward(
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -565,33 +633,78 @@ def forward(
             )
             return output
 
-        assert not use_local_attn, (
-            "Cascade attention does not support local attention.")
-        # Cascade attention (rare case).
-        cascade_attention(
-            output[:num_actual_tokens],
-            query[:num_actual_tokens],
-            key_cache,
-            value_cache,
-            cu_query_lens=attn_metadata.query_start_loc,
-            max_query_len=attn_metadata.max_query_len,
-            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
-            prefix_kv_lens=attn_metadata.prefix_kv_lens,
-            suffix_kv_lens=attn_metadata.suffix_kv_lens,
-            max_kv_len=attn_metadata.max_seq_len,
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
             alibi_slopes=self.alibi_slopes,
-            sliding_window=self.sliding_window,
-            logits_soft_cap=self.logits_soft_cap,
-            block_table=attn_metadata.block_table,
-            common_prefix_len=attn_metadata.common_prefix_len,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
             fa_version=self.vllm_flash_attn_version,
-            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
-            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
         )
+
         return output
 

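For context on the encoder metadata used above: encoder_seq_start_loc follows the usual varlen-attention layout, a cumulative sum of per-request encoder lengths passed as cu_seqlens_q/cu_seqlens_k, and encoder self-attention runs over it with causal=False. A small plain-PyTorch sketch of that layout (illustrative lengths, no FlashAttention kernel involved):

import torch

# Per-request encoder sequence lengths in the batch (illustrative values,
# matching the docstring example above: lengths [4, 6] -> [0, 4, 10]).
encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)

# (batch_size + 1,) cumulative start locations used as cu_seqlens_q/_k.
encoder_seq_start_loc = torch.zeros(len(encoder_seq_lens) + 1, dtype=torch.int32)
encoder_seq_start_loc[1:] = torch.cumsum(encoder_seq_lens, dim=0)
max_encoder_seq_len = int(encoder_seq_lens.max())

assert encoder_seq_start_loc.tolist() == [0, 4, 10]
assert max_encoder_seq_len == 6

# Slicing request i's packed encoder tokens out of a [total_tokens, ...] tensor:
total_tokens = int(encoder_seq_start_loc[-1])
packed = torch.randn(total_tokens, 8, 64)  # [tokens, num_heads, head_size]
start, end = encoder_seq_start_loc[0], encoder_seq_start_loc[1]
request_0 = packed[start:end]  # the 4 encoder tokens of request 0
assert request_0.shape[0] == 4
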
vllm/v1/attention/backends/utils.py

Lines changed: 8 additions & 0 deletions
@@ -59,6 +59,14 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
+    # Encoder/cross-attention specific fields (optional)
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    """(batch_size + 1,), cumulative encoder sequence lengths"""
+    max_encoder_seq_len: Optional[int] = None
+    """Maximum encoder sequence length in batch"""
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    """Slot mapping for cross-attention KV cache"""
+
     def __post_init__(self):
         # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
         # mode.

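For context on the cross_slot_mapping field above: a slot mapping flattens (block, offset) coordinates into per-token cache slot indices, conventionally slot = block_id * block_size + offset. The toy sketch below assumes that convention; the block IDs, block size, and helper name are made up for illustration and are not the allocation code from this commit.

import torch

BLOCK_SIZE = 16

def build_slot_mapping(block_ids: list[int], num_tokens: int) -> torch.Tensor:
    # Toy version of a slot mapping: token i of the sequence lives in slot
    # block_ids[i // BLOCK_SIZE] * BLOCK_SIZE + (i % BLOCK_SIZE).
    slots = [
        block_ids[i // BLOCK_SIZE] * BLOCK_SIZE + (i % BLOCK_SIZE)
        for i in range(num_tokens)
    ]
    return torch.tensor(slots, dtype=torch.int64)

# Example: 20 encoder tokens spread over two (made-up) blocks 7 and 3.
cross_slot_mapping = build_slot_mapping([7, 3], num_tokens=20)
assert cross_slot_mapping[0].item() == 7 * BLOCK_SIZE   # first token
assert cross_slot_mapping[16].item() == 3 * BLOCK_SIZE  # first token of block 2
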
0 commit comments
