@@ -130,6 +130,16 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into the sequences. E.g., if the
+    # sequence lengths are [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
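
The comment above fixes the running example: encoder lengths [4, 6] give start offsets [0, 4, 10]. A minimal sketch, not part of this diff, of how such a tensor could be built from hypothetical per-request encoder lengths (the name encoder_seq_lens is assumed for illustration):

    import torch

    # Hypothetical per-request encoder lengths for a batch of two sequences.
    encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)

    # Cumulative start offsets, shape (batch_size + 1,): [0, 4, 10].
    encoder_seq_start_loc = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
    ])
    max_encoder_seq_len = int(encoder_seq_lens.max())  # 6

    assert encoder_seq_start_loc.tolist() == [0, 4, 10]
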
@@ -142,6 +152,14 @@ class LocalAttentionMetadata:
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
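
A self-contained sketch, not part of this diff, of the property's contract using a toy stand-in dataclass; both encoder fields must be populated before encoder attention can run:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class _EncoderMetadataSketch:
        """Toy stand-in for the encoder fields added to FlashAttentionMetadata."""
        encoder_seq_start_loc: Optional[torch.Tensor] = None
        max_encoder_seq_len: Optional[int] = None

        @property
        def is_all_encoder_attn_metadata_set(self) -> bool:
            return (self.encoder_seq_start_loc is not None
                    and self.max_encoder_seq_len is not None)


    assert not _EncoderMetadataSketch().is_all_encoder_attn_metadata_set
    assert _EncoderMetadataSketch(
        encoder_seq_start_loc=torch.tensor([0, 4, 10]),
        max_encoder_seq_len=6,
    ).is_all_encoder_attn_metadata_set
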
@@ -219,7 +237,13 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
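
For cross-attention the keys and values come from the encoder output, so the kernel's max_seq_len must bound the encoder lengths rather than the decoder ones. A hedged sketch of the same selection logic with hypothetical inputs:

    from typing import Optional

    import torch

    def pick_max_seq_len(seq_lens_cpu: torch.Tensor,
                         max_encoder_seq_len: Optional[int],
                         cross_slot_mapping: Optional[torch.Tensor]) -> int:
        # Mirrors the branch above: cross-attention K/V come from the encoder,
        # so the maximum encoder length bounds the kernel's max_seq_len.
        if cross_slot_mapping is not None and max_encoder_seq_len is not None:
            return max_encoder_seq_len
        return int(seq_lens_cpu.max())

    decoder_lens = torch.tensor([7, 12])
    print(pick_max_seq_len(decoder_lens, None, None))             # 12 (decoder path)
    print(pick_max_seq_len(decoder_lens, 10, torch.arange(20)))   # 10 (cross-attention)
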
@@ -374,6 +398,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
         )
         return attn_metadata
 
@@ -428,18 +456,32 @@ def __init__(
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
                 and not flash_attn_supports_fp8():
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")
 
+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise `False`.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
     def forward(
         self,
         layer: torch.nn.Module,
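
A sketch, not part of this diff, of the causal flag expected for each attention type; it assumes the AttentionType constants compare like the lowercase string labels below, which is an assumption about vLLM internals rather than something shown in the hunk:

    # Expected causal flag per attention type (string labels are assumed).
    _EXPECTED_CAUSAL = {
        "decoder": True,            # autoregressive self-attention
        "encoder": False,           # bidirectional encoder self-attention
        "encoder_only": False,      # bidirectional, no decoder involved
        "encoder_decoder": False,   # cross-attention over full encoder output
    }

    def _get_causal_option_sketch(attn_type: str) -> bool:
        return attn_type not in ("encoder", "encoder_only", "encoder_decoder")

    for attn_type, expected in _EXPECTED_CAUSAL.items():
        assert _get_causal_option_sketch(attn_type) is expected
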
@@ -476,6 +518,14 @@ def forward(
             # Profiling run.
             return output
 
+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -486,22 +536,40 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
             # not padded. However, we don't need to do key[:num_actual_tokens]
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
             reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                 self.kv_cache_dtype,
                 layer._k_scale,
                 layer._v_scale,
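
Cross-attention keys and values are cached in encoder-token slots, so the write above uses cross_slot_mapping; every other type keeps the regular slot_mapping. A small illustrative sketch of that selection, with string attention-type values assumed:

    from typing import Optional

    import torch

    def pick_slot_mapping(attn_type: str,
                          slot_mapping: torch.Tensor,
                          cross_slot_mapping: Optional[torch.Tensor]) -> torch.Tensor:
        # Cross-attention K/V live in encoder-token slots; everything else
        # uses the regular decoder slot mapping.
        if attn_type == "encoder_decoder":  # assumed string value
            assert cross_slot_mapping is not None
            return cross_slot_mapping
        return slot_mapping

    slots = pick_slot_mapping("encoder_decoder",
                              slot_mapping=torch.arange(8),
                              cross_slot_mapping=torch.arange(100, 110))
    print(slots)  # tensor([100, 101, ..., 109])
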
@@ -539,7 +607,7 @@ def forward(
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
@@ -551,7 +619,7 @@ def forward(
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -565,33 +633,78 @@ def forward(
             )
             return output
 
-        assert not use_local_attn, (
-            "Cascade attention does not support local attention.")
-        # Cascade attention (rare case).
-        cascade_attention(
-            output[:num_actual_tokens],
-            query[:num_actual_tokens],
-            key_cache,
-            value_cache,
-            cu_query_lens=attn_metadata.query_start_loc,
-            max_query_len=attn_metadata.max_query_len,
-            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
-            prefix_kv_lens=attn_metadata.prefix_kv_lens,
-            suffix_kv_lens=attn_metadata.suffix_kv_lens,
-            max_kv_len=attn_metadata.max_seq_len,
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
             alibi_slopes=self.alibi_slopes,
-            sliding_window=self.sliding_window,
-            logits_soft_cap=self.logits_soft_cap,
-            block_table=attn_metadata.block_table,
-            common_prefix_len=attn_metadata.common_prefix_len,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
             fa_version=self.vllm_flash_attn_version,
-            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
-            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
         )
+
         return output
 
 
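
As a sanity reference for the new bidirectional path, the varlen call in _forward_encoder_attention can be checked against a slow per-sequence loop over the cu_seqlens segments using PyTorch's scaled_dot_product_attention with is_causal=False. This sketch is not part of the diff, ignores GQA and FP8, and assumes PyTorch 2.1+ for the scale argument:

    import torch
    import torch.nn.functional as F

    def encoder_attention_reference(query: torch.Tensor,      # [total_tokens, H, D]
                                    key: torch.Tensor,        # [total_tokens, H, D]
                                    value: torch.Tensor,      # [total_tokens, H, D]
                                    cu_seqlens: torch.Tensor,  # [batch_size + 1]
                                    scale: float) -> torch.Tensor:
        out = torch.empty_like(query)
        for i in range(cu_seqlens.numel() - 1):
            s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
            # [H, seq, D] layout for SDPA; no causal mask for encoder attention.
            o = F.scaled_dot_product_attention(
                query[s:e].transpose(0, 1),
                key[s:e].transpose(0, 1),
                value[s:e].transpose(0, 1),
                is_causal=False,
                scale=scale,
            )
            out[s:e] = o.transpose(0, 1)
        return out

    q = torch.randn(10, 4, 64)
    k = torch.randn(10, 4, 64)
    v = torch.randn(10, 4, 64)
    ref = encoder_attention_reference(q, k, v, torch.tensor([0, 4, 10]), 64 ** -0.5)
    print(ref.shape)  # torch.Size([10, 4, 64])
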