[WIP] Add support for custom DeepSeek modelling in ACL Graph mode #677

Status: Open. Wants to merge 1 commit into base: main.
128 changes: 89 additions & 39 deletions vllm_ascend/attention/mla_v1.py
@@ -6,6 +6,8 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl)
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, RowParallelLinear,
                                               UnquantizedLinearMethod)
@@ -524,40 +526,48 @@
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
-        if attn_metadata is None:
-            # Profiling run.
-            return output
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_mla_attention_with_output(
[Collaborator review comment]
Can we use the attention interface unified_ascend_attention_with_output, which is already registered in attention_v1.py?

+                query=hidden_states_or_q_c,
+                key=k_c_normed,
+                value=k_pe,
+                output=output,
+                layer_name=layer.layer_name)
+        else:
+            if attn_metadata is None:
+                # Profiling run.
+                return output

-        num_actual_toks = attn_metadata.num_actual_tokens
+            num_actual_toks = attn_metadata.num_actual_tokens

-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-        k_c_normed = k_c_normed[:num_actual_toks, ...]
-        k_pe = k_pe[:num_actual_toks, ...]
+            # Inputs and outputs may be padded for CUDA graphs
+            output_padded = output
+            output = output[:num_actual_toks, ...]
+            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+            k_c_normed = k_c_normed[:num_actual_toks, ...]
+            k_pe = k_pe[:num_actual_toks, ...]

-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
+            # Restore head dim (for rotary embedding)
+            k_pe = k_pe.unsqueeze(1)

-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
+            assert attn_metadata.num_decodes is not None and \
+                attn_metadata.num_prefills is not None and \
+                attn_metadata.num_decode_tokens is not None

-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
+            has_decode = attn_metadata.num_decodes > 0
+            has_prefill = attn_metadata.num_prefills > 0
+            num_decode_tokens = attn_metadata.num_decode_tokens

-        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-        decode_k_pe = k_pe[:num_decode_tokens]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            decode_k_pe = k_pe[:num_decode_tokens]

-        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-        prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+            prefill_k_pe = k_pe[num_decode_tokens:]
+            prefill_k_c_normed = k_c_normed[num_decode_tokens:]

            if has_decode:
                assert attn_metadata.decode is not None
@@ -569,17 +579,17 @@
                    decode_k_pe,
                    max_seq_len=attn_metadata.decode.max_seq_lens)

-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+            if has_prefill:
+                assert attn_metadata.prefill is not None
+                prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
+                    .view(-1, self.num_heads, self.qk_head_dim)
+                prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]

                prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                    attn_metadata.prefill.input_positions,

[Check failure on line 589 in vllm_ascend/attention/mla_v1.py: GitHub Actions / mypy (3.11)]
Item "None" of "AscendMLAPrefillMetadata | None" has no attribute "input_positions" [union-attr]

                    prefill_q_pe.contiguous(),
                    prefill_k_pe,
                    max_seq_len=attn_metadata.prefill.max_seq_lens)

[Check failure on line 592 in vllm_ascend/attention/mla_v1.py: GitHub Actions / mypy (3.11)]
Item "None" of "AscendMLAPrefillMetadata | None" has no attribute "max_seq_lens" [union-attr]

            if kv_cache.numel() > 0:
                key = torch.cat([
@@ -591,13 +601,53 @@
                    key_cache=kv_cache,
                    slot_indices=attn_metadata.slot_mapping.flatten())

-        if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-
-        if has_decode:
-            output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
-
-        return output_padded
+            if has_prefill:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+
+            if has_decode:
+                output[:num_decode_tokens] = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+            return output_padded


+def unified_ascend_mla_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self,
+                      query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      output,
+                      trace_flag=False)
+    return


+def unified_mla_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return


+direct_register_custom_op(
+    op_name="unified_ascend_mla_attention_with_output",
+    op_func=unified_ascend_mla_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
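
For readers unfamiliar with this registration pattern: direct_register_custom_op pairs a real implementation (dispatched here to the Ascend PrivateUse1 key) with a shape-only fake implementation, so that torch.compile / ACL graph capture can trace the op without launching kernels, while the real kernel runs at execution time and writes into the pre-allocated output. Below is a minimal, self-contained sketch of the same pattern with a toy op; the op name, function names, and the CPU dispatch key are invented for illustration only and assume vllm.utils.direct_register_custom_op behaves as in upstream vLLM at the time of this PR. It is a sketch, not part of the change.

import torch
from vllm.utils import direct_register_custom_op


def toy_scale_with_output(x: torch.Tensor, scale: float,
                          output: torch.Tensor) -> None:
    # Real implementation: writes the result into the caller-provided buffer.
    output.copy_(x * scale)


def toy_scale_with_output_fake(x: torch.Tensor, scale: float,
                               output: torch.Tensor) -> None:
    # Fake (meta) implementation: used only while tracing, does no work.
    return


direct_register_custom_op(
    op_name="toy_scale_with_output",
    op_func=toy_scale_with_output,
    mutates_args=["output"],       # tells the compiler `output` is mutated in place
    fake_impl=toy_scale_with_output_fake,
    dispatch_key="CPU",            # the PR uses "PrivateUse1" (NPU) instead
)

x = torch.randn(4)
out = torch.empty_like(x)
torch.ops.vllm.toy_scale_with_output(x, 2.0, out)  # out now holds 2 * x
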
36 changes: 36 additions & 0 deletions vllm_ascend/models/deepseek_v2.py
@@ -34,6 +34,7 @@
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                         get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
@@ -206,13 +207,47 @@ def __init__(
        self.dp_size = get_dp_group().world_size
        batch_size = vllm_config.scheduler_config.max_num_seqs
        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
+        additional_config = vllm_config.additional_config
+        self.enable_graph_mode = False
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)

        params_dtype = torch.get_default_dtype()
        self.final_hidden_states = torch.zeros(
            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
        self.tp_group = get_tp_group().device_group

+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.enable_graph_mode:
+            return self._forward(hidden_states)
+        else:
+            return self._forward_eager(hidden_states)
+
+    def _forward_eager(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+
+        # NOTE(Yizhou): Quite strange that the order of these two operations
+        # is reversed in the original vLLM code
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        return final_hidden_states.view(num_tokens, hidden_dim)

    def _forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
[Collaborator review comment]
Have we tested torchair on this? Accessing attn_metadata from the global context may cause graph-related issues; better to confirm it with torchair.

        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata is None:
            # for profile run
Expand Down Expand Up @@ -528,6 +563,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class CustomDeepseekV2Model(nn.Module):

fall_back_to_pt_during_load = False
Expand Down
10 changes: 5 additions & 5 deletions vllm_ascend/ops/fused_moe.py
@@ -188,7 +188,7 @@ def fused_experts(
    # assert dtype in [torch.float32, torch.float16, torch.bfloat16
    #                  ], "Only float32, float16, and bfloat16 are supported"

-    if expert_map is not None:
+    if True or expert_map is not None:
        # Generate token indices and flatten
        token_indices = (torch.arange(num_tokens,
                                      device=device,
@@ -198,7 +198,7 @@ def fused_experts(
        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
-        local_experts_flat = expert_map[experts_flat]
+        local_experts_flat = expert_map[experts_flat] if expert_map is not None else experts_flat

        # Filter valid token-expert pairs
        mask = local_experts_flat != -1
@@ -270,7 +270,7 @@ def fused_experts(

    down_out_list = torch.cat(down_out_list, dim=0)

-    if expert_map is not None:
+    if True or expert_map is not None:
        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros(*original_shape,
@@ -614,8 +614,8 @@ def __init__(self,
    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
-                is_prefill: bool,
-                top_k=None):
+                is_prefill: bool = True,
+                top_k: Optional[int] = None):
        assert self.quant_method is not None

        if top_k:
3 changes: 2 additions & 1 deletion vllm_ascend/platform.py
@@ -143,7 +143,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                "using only ACL Graph mode")
            compilation_config.use_inductor = False
            compilation_config.splitting_ops.extend(
-                ["vllm.unified_ascend_attention_with_output"])
+                ["vllm.unified_ascend_attention_with_output",
+                 "vllm.unified_ascend_mla_attention_with_output"])

        if vllm_config.additional_config is not None:
            enable_graph_mode = vllm_config.additional_config.get(
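
For context on how these flags reach platform.py and CustomDeepseekV2MoE, here is a hedged sketch of the user-facing configuration. The model name is a placeholder, and it assumes additional_config passed to the LLM entry point is forwarded to VllmConfig.additional_config the way vllm-ascend's existing graph-mode flag is; it has not been verified against this exact revision.

from vllm import LLM, SamplingParams

# Assumed usage: enable the ACL graph path this PR extends.
# "enable_graph_mode" is the key read in platform.py and deepseek_v2.py above;
# the model id below is only a placeholder.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={"enable_graph_mode": True},
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))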