@@ -6,6 +6,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
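
The two new imports feed the custom-op plumbing introduced at the bottom of this diff: non-tensor state (attention metadata, the attention layer object, its KV cache) cannot travel through the torch.ops dispatcher, so the op body recovers it from the forward context at call time. A minimal sketch of that lookup is shown here, mirroring the wrapper defined later in the diff; the helper name is hypothetical, and it only works inside vLLM's model-execution scope, where the runner has populated the context via set_forward_context.

from vllm.forward_context import ForwardContext, get_forward_context


def lookup_layer_state(layer_name: str):
    # Per-step state is published on the forward context by the model runner
    # rather than passed as op arguments.
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata            # this step's attention metadata
    layer = forward_context.no_compile_layers[layer_name]    # the registered attention layer
    kv_cache = layer.kv_cache[forward_context.virtual_engine]  # that layer's KV cache
    return layer, attn_metadata, kv_cache
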
@@ -483,76 +485,124 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
-
         assert output is not None, "Output tensor must be provided."

-        if attn_metadata is None:
-            # Profiling run.
-            return output
-
-        num_actual_toks = attn_metadata.num_actual_tokens
-
-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-        k_c_normed = k_c_normed[:num_actual_toks, ...]
-        k_pe = k_pe[:num_actual_toks, ...]
-
-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
-
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
-
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-
-        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-        decode_k_pe = k_pe[:num_decode_tokens]
-
-        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-        prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
-
-        if has_decode:
-            assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
-            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
-
-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
-
-        if kv_cache.numel() > 0:
-            concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
-                                 attn_metadata.slot_mapping.flatten())
-            # TODO: replaced back to ascend ops
-            # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
-            # torch_npu._npu_reshape_and_cache_siso(
-            #     key=key,
-            #     key_cache=kv_cache,
-            #     slot_indices=attn_metadata.slot_mapping.flatten())
-
-        if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-
-        if has_decode:
-            output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
-
-        return output_padded
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_mla_attention_with_output(
+                query=hidden_states_or_q_c,
+                key=k_c_normed,
+                value=k_pe,
+                output=output,
+                layer_name=layer.layer_name)
+        else:
+            if attn_metadata is None:
+                # Profiling run.
+                return output
+
+            num_actual_toks = attn_metadata.num_actual_tokens
+
+            # Inputs and outputs may be padded for CUDA graphs
+            output_padded = output
+            output = output[:num_actual_toks, ...]
+            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+            k_c_normed = k_c_normed[:num_actual_toks, ...]
+            k_pe = k_pe[:num_actual_toks, ...]
+
+            # Restore head dim (for rotary embedding)
+            k_pe = k_pe.unsqueeze(1)
+
+            assert attn_metadata.num_decodes is not None and \
+                attn_metadata.num_prefills is not None and \
+                attn_metadata.num_decode_tokens is not None
+
+            has_decode = attn_metadata.num_decodes > 0
+            has_prefill = attn_metadata.num_prefills > 0
+            num_decode_tokens = attn_metadata.num_decode_tokens
+
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            decode_k_pe = k_pe[:num_decode_tokens]
+
+            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+            prefill_k_pe = k_pe[num_decode_tokens:]
+            prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+
+            if has_decode:
+                assert attn_metadata.decode is not None
+                decode_ql_nope, decode_q_pe = \
+                    self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+                decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
+                    attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
+                    decode_k_pe)
+
+            if has_prefill:
+                assert attn_metadata.prefill is not None
+                prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
+                    .view(-1, self.num_heads, self.qk_head_dim)
+                prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+
+                prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
+                    attn_metadata.prefill.input_positions,
+                    prefill_q_pe.contiguous(), prefill_k_pe)
+
+            if kv_cache.numel() > 0:
+                concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
+                                     attn_metadata.slot_mapping.flatten())
+                # TODO: replaced back to ascend ops
+                # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
+                # torch_npu._npu_reshape_and_cache_siso(
+                #     key=key,
+                #     key_cache=kv_cache,
+                #     slot_indices=attn_metadata.slot_mapping.flatten())
+
+            if has_prefill:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+
+            if has_decode:
+                output[:num_decode_tokens] = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+            return output_padded
+
+
+def unified_ascend_mla_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self,
+                      query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      output,
+                      trace_flag=False)
+    return
+
+
+def unified_mla_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_mla_attention_with_output",
+    op_func=unified_ascend_mla_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
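
The registration above follows vLLM's in-place custom-op pattern: the real body writes into output and returns nothing, the fake body is a no-op with the same signature so graph tracing can proceed without touching NPU kernels, and dispatch_key="PrivateUse1" routes execution to the NPU device backend. Below is a minimal sketch of the same pattern on a toy op; the op and helper names are made up for illustration, and it assumes the direct_register_custom_op signature used above.

import torch

from vllm.utils import direct_register_custom_op


def toy_scale_into(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Real implementation: writes the result into `out` in place, returns nothing.
    out.copy_(x * scale)


def toy_scale_into_fake(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Fake (meta) implementation: no computation; the tracer only needs to know
    # that `out` is mutated, which mutates_args declares below.
    return


direct_register_custom_op(
    op_name="toy_scale_into",
    op_func=toy_scale_into,
    mutates_args=["out"],          # analogous to mutates_args=["output"] above
    fake_impl=toy_scale_into_fake,
    dispatch_key="PrivateUse1",    # NPU dispatch key, mirroring the registration above
)

# Once registered, the op is callable as torch.ops.vllm.toy_scale_into(x, out, 2.0),
# the same way forward() above calls torch.ops.vllm.unified_ascend_mla_attention_with_output.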