From ab73b9883e5c4617aa4a813df019d4471313237f Mon Sep 17 00:00:00 2001
From: Yizhou Liu
Date: Fri, 25 Apr 2025 16:55:25 +0800
Subject: [PATCH] [Feature] Add support for custom DeepSeek modeling in ACL Graph mode

Signed-off-by: Yizhou Liu
---
 vllm_ascend/attention/mla_v1.py   | 128 +++++++++++++++++++++---------
 vllm_ascend/models/deepseek_v2.py |  36 +++++++++
 vllm_ascend/ops/fused_moe.py      |  10 +--
 vllm_ascend/platform.py           |   3 +-
 4 files changed, 132 insertions(+), 45 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 64a5431e9..bfea76993 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -6,6 +6,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -524,40 +526,48 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
-        assert output is not None, "Output tensor must be provided."
-        if attn_metadata is None:
-            # Profiling run.
-            return output
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_mla_attention_with_output(
+                query=hidden_states_or_q_c,
+                key=k_c_normed,
+                value=k_pe,
+                output=output,
+                layer_name=layer.layer_name)
+        else:
+            if attn_metadata is None:
+                # Profiling run.
+                return output
 
-        num_actual_toks = attn_metadata.num_actual_tokens
+            num_actual_toks = attn_metadata.num_actual_tokens
 
-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-        k_c_normed = k_c_normed[:num_actual_toks, ...]
-        k_pe = k_pe[:num_actual_toks, ...]
+            # Inputs and outputs may be padded for CUDA graphs
+            output_padded = output
+            output = output[:num_actual_toks, ...]
+            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+            k_c_normed = k_c_normed[:num_actual_toks, ...]
+            k_pe = k_pe[:num_actual_toks, ...]
 
-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
+            # Restore head dim (for rotary embedding)
+            k_pe = k_pe.unsqueeze(1)
 
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
+            assert attn_metadata.num_decodes is not None and \
+                attn_metadata.num_prefills is not None and \
+                attn_metadata.num_decode_tokens is not None
 
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
+            has_decode = attn_metadata.num_decodes > 0
+            has_prefill = attn_metadata.num_prefills > 0
+            num_decode_tokens = attn_metadata.num_decode_tokens
 
-        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-        decode_k_pe = k_pe[:num_decode_tokens]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            decode_k_pe = k_pe[:num_decode_tokens]
 
-        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-        prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+            prefill_k_pe = k_pe[num_decode_tokens:]
+            prefill_k_c_normed = k_c_normed[num_decode_tokens:]
 
             if has_decode:
                 assert attn_metadata.decode is not None
@@ -569,11 +579,11 @@ def forward(
                     decode_k_pe,
                     max_seq_len=attn_metadata.decode.max_seq_lens)
 
-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+            if has_prefill:
+                assert attn_metadata.prefill is not None
+                prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
+                    .view(-1, self.num_heads, self.qk_head_dim)
+                prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                     attn_metadata.prefill.input_positions,
@@ -591,13 +601,53 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
 
-        if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-
-        if has_decode:
-            output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
-
-        return output_padded
+            if has_prefill:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+
+            if has_decode:
+                output[:num_decode_tokens] = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+            return output_padded
+
+
+def unified_ascend_mla_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self,
+                      query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      output,
+                      trace_flag=False)
+    return
+
+
+def unified_mla_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_mla_attention_with_output",
+    op_func=unified_ascend_mla_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index ef4751977..7d6546bfc 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -34,6 +34,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_dp_group, get_pp_group,
@@ -206,6 +207,11 @@ def __init__(
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
         self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
+        additional_config = vllm_config.additional_config
+        self.enable_graph_mode = False
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
 
         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
@@ -213,6 +219,35 @@ def __init__(
         self.tp_group = get_tp_group().device_group
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.enable_graph_mode:
+            return self._forward(hidden_states)
+        else:
+            return self._forward_eager(hidden_states)
+
+    def _forward_eager(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+
+        # NOTE(Yizhou): Quite strange that the order of these two operations
+        # is reversed in the original vLLM code
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+    def _forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
@@ -528,6 +563,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class CustomDeepseekV2Model(nn.Module):
 
     fall_back_to_pt_during_load = False
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 2c25e0c76..087617840 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -188,7 +188,7 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     #                  ], "Only float32, float16, and bfloat16 are supported"
 
-    if expert_map is not None:
+    if True or expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
                                       device=device,
@@ -198,7 +198,7 @@ def fused_experts(
         # Flatten token-to-expert mappings and map to local experts
         weights_flat = topk_weights.view(-1)
         experts_flat = topk_ids.view(-1)
-        local_experts_flat = expert_map[experts_flat]
+        local_experts_flat = expert_map[experts_flat] if expert_map is not None else experts_flat
 
         # Filter valid token-expert pairs
         mask = local_experts_flat != -1
@@ -270,7 +270,7 @@ def fused_experts(
 
     down_out_list = torch.cat(down_out_list, dim=0)
 
-    if expert_map is not None:
+    if True or expert_map is not None:
         weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
 
         final_hidden_states = torch.zeros(*original_shape,
@@ -614,8 +614,8 @@ def __init__(self,
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
-                is_prefill: bool,
-                top_k=None):
+                is_prefill: bool = True,
+                top_k: Optional[int] = None):
         assert self.quant_method is not None
 
         if top_k:
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 5d2c8acd5..c39cee447 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -143,7 +143,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "using only ACL Graph mode")
             compilation_config.use_inductor = False
             compilation_config.splitting_ops.extend(
-                ["vllm.unified_ascend_attention_with_output"])
+                ["vllm.unified_ascend_attention_with_output",
+                 "vllm.unified_ascend_mla_attention_with_output"])
 
         if vllm_config.additional_config is not None:
             enable_graph_mode = vllm_config.additional_config.get(