feat: add mtp ut case #2166

Open · wants to merge 19 commits into main
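
Assuming the repository's standard e2e layout, the updated case can be run on a single NPU card with pytest, for example:

    pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py::test_mtp_correctness
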
41 changes: 27 additions & 14 deletions tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
@@ -50,8 +50,6 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"


@pytest.mark.skipif(
True, reason="TODO: Enable me after test_mtp_correctness is fixed")
def test_mtp_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
@@ -61,28 +59,43 @@ def test_mtp_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using mtp speculative decoding.
     '''
-    ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
+    ref_llm = LLM(model=model_name,
+                  gpu_memory_utilization=0.5,
+                  max_model_len=256,
+                  enforce_eager=True)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
 
-    spec_llm = LLM(model=model_name,
-                   trust_remote_code=True,
-                   speculative_config={
-                       "method": "deepseek_mtp",
-                       "num_speculative_tokens": 1,
-                   },
-                   max_model_len=256,
-                   enforce_eager=True)
+    spec_llm = LLM(
+        model=model_name,
+        tensor_parallel_size=1,
+        max_num_seqs=256,
+        gpu_memory_utilization=0.6,
+        distributed_executor_backend="mp",
+        enable_expert_parallel=True,
+        speculative_config={
+            "method": "deepseek_mtp",
+            "num_speculative_tokens": 1,
+        },
+        trust_remote_code=True,
+        enforce_eager=True,
+        max_model_len=2000,
+        additional_config={"ascend_scheduler_config": {
+            "enabled": True
+        }})
 
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
     misses = 0
     for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+        ref_token_ids = ref_output.outputs[0].token_ids
+        spec_token_ids = spec_output.outputs[0].token_ids
+        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
             matches += 1
         else:
             misses += 1
-            print(f"ref_output: {ref_output.outputs[0].text}")
-            print(f"spec_output: {spec_output.outputs[0].text}")
+        print(f"ref_output: {ref_output.outputs[0].text}")
+        print(f"spec_output: {spec_output.outputs[0].text}")
 
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
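
For context, a minimal sketch of how the 66% heuristic named in the comment above could be asserted; the actual assertion sits below the fold of this diff, so the threshold constant and the failure message here are assumptions:

    # Sketch only: require that most prompts reproduce the reference token ids.
    assert matches > int(0.66 * len(ref_outputs)), (
        f"only {matches}/{len(ref_outputs)} prompts matched "
        f"({misses} diverged from the reference outputs)")

Switching from text equality to the token-id prefix check also tolerates the speculative run emitting a few extra tokens past the end of the reference output.
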
2 changes: 1 addition & 1 deletion vllm_ascend/models/deepseek_mtp.py
@@ -215,4 +215,4 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, previous_hidden_states,
                                    inputs_embeds, spec_step_idx)
-        return hidden_states
\ No newline at end of file
+        return hidden_states
3 changes: 2 additions & 1 deletion vllm_ascend/quantization/quant_config.py
@@ -102,7 +102,8 @@ def get_quant_method(self, layer: torch.nn.Module,
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return AscendUnquantizedFusedMoEMethod(layer.moe)
+                return AscendUnquantizedFusedMoEMethod(
+                    layer.moe if hasattr(layer, 'moe') else None)
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
         elif isinstance(layer, VocabParallelEmbedding):
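
The hasattr guard above lets the unquantized MoE path be constructed even when a FusedMoE layer does not expose a moe config attribute (which can vary across vLLM versions). Below is a minimal, self-contained sketch of the same fallback pattern using getattr with a default; the helper name and the stand-in layer classes are hypothetical and only illustrate the idea:

    class _LayerWithMoeConfig:
        moe = {"num_experts": 8}  # stand-in for a real FusedMoE config object

    class _LayerWithoutMoeConfig:
        pass

    def moe_config_or_none(layer):
        # getattr with a default is equivalent to the hasattr check in the diff
        return getattr(layer, "moe", None)

    print(moe_config_or_none(_LayerWithMoeConfig()))     # {'num_experts': 8}
    print(moe_config_or_none(_LayerWithoutMoeConfig()))  # None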