
Commit 3b0c106

[Feature] Add support for custom DeepSeek modeling in ACL Graph mode

Signed-off-by: Yizhou Liu <[email protected]>
1 parent fa4a5d9

File tree: 4 files changed (+163 -76 lines)

vllm_ascend/attention/mla_v1.py (+120 -70)
```diff
@@ -6,6 +6,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -483,76 +485,124 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
-
         assert output is not None, "Output tensor must be provided."
 
-        if attn_metadata is None:
-            # Profiling run.
-            return output
-
-        num_actual_toks = attn_metadata.num_actual_tokens
-
-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-        k_c_normed = k_c_normed[:num_actual_toks, ...]
-        k_pe = k_pe[:num_actual_toks, ...]
-
-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
-
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
-
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-
-        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-        decode_k_pe = k_pe[:num_decode_tokens]
-
-        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-        prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
-
-        if has_decode:
-            assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
-            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
-
-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
-
-        if kv_cache.numel() > 0:
-            concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
-                                 attn_metadata.slot_mapping.flatten())
-            # TODO: replaced back to ascend ops
-            # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
-            # torch_npu._npu_reshape_and_cache_siso(
-            #     key=key,
-            #     key_cache=kv_cache,
-            #     slot_indices=attn_metadata.slot_mapping.flatten())
-
-        if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-
-        if has_decode:
-            output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
-
-        return output_padded
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_mla_attention_with_output(
+                query=hidden_states_or_q_c,
+                key=k_c_normed,
+                value=k_pe,
+                output=output,
+                layer_name=layer.layer_name)
+        else:
+            if attn_metadata is None:
+                # Profiling run.
+                return output
+
+            num_actual_toks = attn_metadata.num_actual_tokens
+
+            # Inputs and outputs may be padded for CUDA graphs
+            output_padded = output
+            output = output[:num_actual_toks, ...]
+            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+            k_c_normed = k_c_normed[:num_actual_toks, ...]
+            k_pe = k_pe[:num_actual_toks, ...]
+
+            # Restore head dim (for rotary embedding)
+            k_pe = k_pe.unsqueeze(1)
+
+            assert attn_metadata.num_decodes is not None and \
+                attn_metadata.num_prefills is not None and \
+                attn_metadata.num_decode_tokens is not None
+
+            has_decode = attn_metadata.num_decodes > 0
+            has_prefill = attn_metadata.num_prefills > 0
+            num_decode_tokens = attn_metadata.num_decode_tokens
+
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            decode_k_pe = k_pe[:num_decode_tokens]
+
+            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+            prefill_k_pe = k_pe[num_decode_tokens:]
+            prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+
+            if has_decode:
+                assert attn_metadata.decode is not None
+                decode_ql_nope, decode_q_pe = \
+                    self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+                decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
+                    attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
+                    decode_k_pe)
+
+            if has_prefill:
+                assert attn_metadata.prefill is not None
+                prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
+                    .view(-1, self.num_heads, self.qk_head_dim)
+                prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+
+                prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
+                    attn_metadata.prefill.input_positions,
+                    prefill_q_pe.contiguous(), prefill_k_pe)
+
+            if kv_cache.numel() > 0:
+                concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
+                                     attn_metadata.slot_mapping.flatten())
+                # TODO: replaced back to ascend ops
+                # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
+                # torch_npu._npu_reshape_and_cache_siso(
+                #     key=key,
+                #     key_cache=kv_cache,
+                #     slot_indices=attn_metadata.slot_mapping.flatten())
+
+            if has_prefill:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+
+            if has_decode:
+                output[:num_decode_tokens] = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+            return output_padded
+
+
+def unified_ascend_mla_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self,
+                      query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      output,
+                      trace_flag=False)
+    return
+
+
+def unified_mla_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_mla_attention_with_output",
+    op_func=unified_ascend_mla_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
```
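
For readers unfamiliar with the pattern: the real `unified_ascend_mla_attention_with_output` pulls the attention layer and its KV cache out of vLLM's `ForwardContext`, the `_fake` variant gives graph tracing a shape-only stand-in, and `direct_register_custom_op` exposes the pair under the `torch.ops.vllm` namespace. Below is a minimal, self-contained sketch of that registration pattern; the op name `toy_mla_stub`, the toy arithmetic, and the `CPU` dispatch key are illustrative assumptions for running off-NPU, not part of this commit.

```python
# Minimal sketch of the custom-op pattern used above (illustrative only).
# Assumes vLLM is installed; `toy_mla_stub` and the CPU dispatch key are
# made up for demonstration and are not part of this commit.
import torch
from vllm.utils import direct_register_custom_op


def toy_mla_stub(query: torch.Tensor, output: torch.Tensor) -> None:
    # Real impl: mutates `output` in place, the same way the MLA wrapper
    # writes attention results into the preallocated output buffer.
    output.copy_(query * 2.0)


def toy_mla_stub_fake(query: torch.Tensor, output: torch.Tensor) -> None:
    # Fake impl: no computation, so torch.compile / graph capture can trace
    # through the call without running the kernel.
    return


direct_register_custom_op(
    op_name="toy_mla_stub",
    op_func=toy_mla_stub,
    mutates_args=["output"],      # tells the schema which arg is written to
    fake_impl=toy_mla_stub_fake,
    dispatch_key="CPU",           # the commit registers "PrivateUse1" (NPU)
)

q = torch.randn(4, 8)
out = torch.empty_like(q)
torch.ops.vllm.toy_mla_stub(query=q, output=out)  # dispatched like the MLA op
```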

vllm_ascend/models/deepseek_v2.py (+36 -0)
```diff
@@ -33,6 +33,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_dp_group, get_pp_group,
@@ -133,12 +134,46 @@ def __init__(
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
         self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
+        additional_config = vllm_config.additional_config
+        self.enable_graph_mode = False
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
 
         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
             [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
         self.tp_group = get_tp_group().device_group
 
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.enable_graph_mode:
+            return self._forward(hidden_states)
+        else:
+            return self._forward_eager(hidden_states)
+
+    def _forward_eager(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+
+        # NOTE(Yizhou): Quite strange that the order of these two operations
+        # is reversed in the original vLLM code
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
@@ -454,6 +489,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class CustomDeepseekV2Model(nn.Module):
 
     fall_back_to_pt_during_load = False
```
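
The new `enable_graph_mode` branch is driven by `vllm_config.additional_config`. A plausible way to exercise it from the Python API is sketched below; the model name and sampling settings are placeholders, and the `additional_config` keyword is assumed to match vllm-ascend's documented usage rather than anything added in this commit.

```python
# Hypothetical usage sketch (not part of this commit): passing
# additional_config={"enable_graph_mode": True} so the MoE layer above takes
# the _forward (graph-mode) path instead of _forward_eager.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",            # placeholder model id
    trust_remote_code=True,
    additional_config={"enable_graph_mode": True},   # read in __init__ above
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```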

vllm_ascend/ops/fused_moe.py (+5 -5)
```diff
@@ -188,7 +188,7 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     # ], "Only float32, float16, and bfloat16 are supported"
 
-    if expert_map is not None:
+    if True or expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
                                       device=device,
@@ -198,7 +198,7 @@ def fused_experts(
         # Flatten token-to-expert mappings and map to local experts
         weights_flat = topk_weights.view(-1)
         experts_flat = topk_ids.view(-1)
-        local_experts_flat = expert_map[experts_flat]
+        local_experts_flat = expert_map[experts_flat] if expert_map is not None else experts_flat
 
         # Filter valid token-expert pairs
         mask = local_experts_flat != -1
@@ -270,7 +270,7 @@ def fused_experts(
 
     down_out_list = torch.cat(down_out_list, dim=0)
 
-    if expert_map is not None:
+    if True or expert_map is not None:
        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
 
        final_hidden_states = torch.zeros(*original_shape,
@@ -614,8 +614,8 @@ def __init__(self,
    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
-               is_prefill: bool,
-               top_k=None):
+               is_prefill: bool = True,
+               top_k: Optional[int] = None):
        assert self.quant_method is not None
 
        if top_k:
```

vllm_ascend/platform.py (+2 -1)
```diff
@@ -144,7 +144,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "using only ACL Graph mode")
             compilation_config.use_inductor = False
             compilation_config.splitting_ops.extend(
-                ["vllm.unified_ascend_attention_with_output"])
+                ["vllm.unified_ascend_attention_with_output",
+                 "vllm.unified_ascend_mla_attention_with_output"])
 
         if vllm_config.additional_config is not None:
             enable_graph_mode = vllm_config.additional_config.get(
```
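
Listing the new op in `splitting_ops` tells vLLM's piecewise compilation to cut the captured graph at every MLA attention call, so the attention kernel runs outside the captured ACL graph while the surrounding layers stay captured. A rough sketch of the resulting compilation config is below; constructing it by hand and the `CompilationLevel.PIECEWISE` choice are assumptions for illustration, not something shown in this hunk.

```python
# Illustrative only: roughly what the compilation config looks like after
# check_and_update_config runs with ACL Graph mode enabled.
from vllm.config import CompilationConfig, CompilationLevel

compilation_config = CompilationConfig(level=CompilationLevel.PIECEWISE)
compilation_config.use_inductor = False
compilation_config.splitting_ops = [
    "vllm.unified_ascend_attention_with_output",
    "vllm.unified_ascend_mla_attention_with_output",  # graph split point added here
]
```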
