
Commit 304dce7

[Attention] Clean up iRoPE in V1 (#21188)
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent: 6ece16c · commit: 304dce7

9 files changed: +14 additions, -26 deletions

vllm/attention/layer.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -137,6 +137,13 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
 
+        # For v1 we have backend agnostic iRoPE (local chunked attention)
+        # we have to store the flag on the layer so gpu model runner can
+        # set KVSpec appropriately (and pop it so it doesnt get passed to
+        # the backends)
+        if envs.VLLM_USE_V1:
+            self.use_irope = extra_impl_args.pop("use_irope", False)
+
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None and not isinstance(
```
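The hunk above is the heart of the change: the layer consumes `use_irope` itself, popping it out of `extra_impl_args` so the GPU model runner can read the flag from the layer while backend implementations never receive it. A minimal, hypothetical sketch of that kwarg-popping pattern (the `Layer` and `BackendImpl` names below are stand-ins for illustration, not vLLM's real classes):

```python
# Sketch of "consume a layer-level kwarg, hide it from the backend".
# Layer and BackendImpl are hypothetical stand-ins for illustration only.

class BackendImpl:
    def __init__(self, num_heads: int, **kwargs):
        # Backends no longer accept layer-only flags such as use_irope.
        if kwargs:
            raise TypeError(f"unexpected backend kwargs: {sorted(kwargs)}")
        self.num_heads = num_heads


class Layer:
    def __init__(self, num_heads: int, **extra_impl_args):
        # Pop the flag before forwarding the remaining kwargs, mirroring
        # what Attention.__init__ now does for use_irope under V1.
        self.use_irope = extra_impl_args.pop("use_irope", False)
        self.impl = BackendImpl(num_heads, **extra_impl_args)


layer = Layer(num_heads=8, use_irope=True)
assert layer.use_irope and layer.impl.num_heads == 8
```

This is also why the per-backend `use_irope: bool = False` parameters and their "not supported" warnings can be deleted in the files below.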

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -446,17 +446,12 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
-        use_irope: bool = False,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported in V0.")
         if logits_soft_cap is not None:
             logger.warning_once("Torch SPDA does not support logits soft cap. "
                                 "Outputs may be slightly off.")
-        if use_irope:
-            logger.warning_once(
-                "Using irope in Torch SPDA is not supported yet, it will fall"
-                " back to global attention for long context.")
         self.paged_attn_impl = _get_paged_attn_impl()
         self.num_heads = num_heads
         self.head_size = head_size
```

vllm/v1/attention/backends/flash_attn.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -352,7 +352,6 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
-        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -381,7 +380,6 @@ def __init__(
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
-        self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
                 and not flash_attn_supports_fp8():
```

vllm/v1/attention/backends/flashinfer.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -493,7 +493,6 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
-        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -509,7 +508,6 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-        self.use_irope = use_irope
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
```
vllm/v1/attention/backends/pallas.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -148,12 +148,7 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
-        use_irope: bool = False,
     ) -> None:
-        if use_irope:
-            logger.warning_once(
-                "Using irope in Pallas is not supported yet, it will fall back "
-                "to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
```

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -337,7 +337,6 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
-        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -367,7 +366,6 @@ def __init__(
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
-        self.use_irope = use_irope
         if is_quantized_kv_cache(self.kv_cache_dtype):
             raise NotImplementedError(
                 "AiterFlashAttention does not support fp8 kv-cache on this "
```

vllm/v1/attention/backends/triton_attn.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -72,9 +72,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                                              vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()
 
-        self.attention_chunk_size = getattr(vllm_config.scheduler_config,
-                                            'attention_chunk_size', None)
-
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> TritonAttentionMetadata:
@@ -208,7 +205,6 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
-        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -228,8 +224,6 @@ def __init__(
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
-        self.use_irope = use_irope
-
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         TritonAttentionBackend.validate_head_size(head_size)
```

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -2702,8 +2702,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 use_local_attention = (self.attention_chunk_size is not None
-                                       and getattr(attn_module.impl,
-                                                   "use_irope", False))
+                                       and attn_module.use_irope)
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
@@ -2716,13 +2715,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         "attention module can not be with ",
                         "both local attention and sliding window")
                 elif use_local_attention:
-                    kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
+                    kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
                         dtype=self.kv_cache_dtype,
                         attention_chunk_size=self.attention_chunk_size,
-                        use_mla=use_mla))
+                        use_mla=use_mla)
                 else:
                     kv_cache_spec[layer_name] = FullAttentionSpec(
                         block_size=block_size,
```
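With the flag on the layer, the spec selection above no longer needs `getattr` on the backend impl: a decoder layer gets a chunked-local spec only when the model configures an attention chunk size and the layer sets `use_irope`; sliding-window and full attention are handled by the other branches. A simplified, hypothetical sketch of that decision (the returned strings stand in for vLLM's SlidingWindowSpec / ChunkedLocalAttentionSpec / FullAttentionSpec dataclasses):

```python
# Hypothetical condensed view of the branching in get_kv_cache_spec();
# the real code builds spec dataclasses instead of returning strings.
from typing import Optional


def select_kv_cache_spec(sliding_window: Optional[int],
                         use_irope: bool,
                         attention_chunk_size: Optional[int]) -> str:
    use_local_attention = attention_chunk_size is not None and use_irope
    if sliding_window is not None:
        if use_local_attention:
            raise ValueError("attention module can not be with both "
                             "local attention and sliding window")
        return "SlidingWindowSpec"
    if use_local_attention:
        return "ChunkedLocalAttentionSpec"
    return "FullAttentionSpec"


assert select_kv_cache_spec(None, True, 8192) == "ChunkedLocalAttentionSpec"
assert select_kv_cache_spec(None, False, 8192) == "FullAttentionSpec"
```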

vllm/v1/worker/tpu_model_runner.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -519,6 +519,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 continue
 
             if attn_module.attn_type == AttentionType.DECODER:
+                if attn_module.use_irope:
+                    logger.warning_once(
+                        "Using irope in Pallas is not supported yet, it "
+                        "will fall back to global attention for long context.")
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
```
