
[Kernel] DeepGemm MoE : Integrate triton permute / unpermute kernels #20903


Merged: 2 commits, merged on Jul 17, 2025
1 change: 0 additions & 1 deletion tests/kernels/moe/modular_kernel_tools/cli_args.py
@@ -85,7 +85,6 @@ def to_quant_torch_dtype(s: str) -> torch.dtype:
help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
nargs="+",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl."
)
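
For context, a standalone sketch of the parsing behavior after this change (the example value is illustrative, not from the PR): with nargs="+" removed, argparse yields a single int for --fused-moe-chunk-size instead of a list of ints.

import argparse

# Minimal reproduction of the flag as defined above; without nargs="+",
# parse_args() returns an int (or None) rather than a list of ints.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--fused-moe-chunk-size",
    type=int,
    help="Fused moe chunk size used for the non-batched fused experts impl.")

args = parser.parse_args(["--fused-moe-chunk-size", "2048"])
assert args.fused_moe_chunk_size == 2048  # a single int, not [2048]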

@@ -239,6 +239,7 @@ def workspace_shapes(
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader

@@ -116,18 +116,21 @@ def workspace_shapes(
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata)
else:
assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata)

def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
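
The new expert_tokens_metadata / expert_tokens_meta parameter threads per-expert routing information into workspace_shapes. The stand-in below is only a sketch of what such metadata plausibly carries (the real mk.ExpertTokensMetadata is defined in modular_kernel.py; the exact fields here are an assumption): per-expert token counts that let an implementation size its workspaces from actual routing rather than the M * topk worst case, with None preserving the old pessimistic sizing.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExpertTokensMetadataSketch:
    # Number of tokens routed to each local expert, shape (local_num_experts,).
    expert_num_tokens: torch.Tensor
    # Optional host-side copy, convenient when shapes are computed on the CPU.
    expert_num_tokens_cpu: Optional[torch.Tensor] = None


meta = ExpertTokensMetadataSketch(
    expert_num_tokens=torch.tensor([5, 0, 12, 3]),
    expert_num_tokens_cpu=torch.tensor([5, 0, 12, 3]))
print(int(meta.expert_num_tokens.sum()))  # total routed (token, topk) pairs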
1 change: 1 addition & 0 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -271,6 +271,7 @@ def workspace_shapes(
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
101 changes: 53 additions & 48 deletions vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -8,16 +8,16 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous

logger = init_logger(__name__)

@@ -93,18 +93,25 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()

def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert self.block_shape is not None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
num_experts = global_num_experts
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = (M_sum, max(N // 2, K))
workspace2 = (M_sum, max(N, K))
M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
expert_tokens_meta)
assert M_sum % block_m == 0

workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
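
A hedged reference for what a compute_aligned_M-style helper has to produce (the actual implementation lives in fused_moe/deep_gemm_utils.py and may differ in details): an upper bound on the number of permuted rows once each expert's chunk is padded to the DeepGemm block alignment, falling back to a worst-case estimate when per-expert counts are unknown.

from typing import Optional

import torch


def round_up(x: int, align: int) -> int:
    return (x + align - 1) // align * align


def aligned_m_sketch(M: int, num_topk: int, local_num_experts: int,
                     alignment: int,
                     expert_num_tokens_cpu: Optional[torch.Tensor]) -> int:
    if expert_num_tokens_cpu is not None:
        # Exact per-expert counts are available: pad each expert's rows up to
        # the alignment and sum. A sum of multiples of `alignment` is itself
        # a multiple, which is what the assert above relies on.
        return sum(round_up(int(n), alignment) for n in expert_num_tokens_cpu)
    # Worst case: every local expert may need up to (alignment - 1) padding
    # rows on top of the M * num_topk routed rows.
    return round_up(M * num_topk + local_num_experts * (alignment - 1),
                    alignment)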

@@ -131,43 +138,40 @@ def apply(
apply_router_weight_on_input: bool,
):
assert self.block_shape is not None
assert a1q_scale is not None

a1q = hidden_states
_, N, K = w1.size()
M, _ = output.size()
num_topk = topk_ids.size(1)

local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = w1.size(0)
global_num_experts = local_num_experts

assert w2.size(1) == K

a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
a1q,
a1q_scale,
topk_ids,
global_num_experts,
expert_map,
self.block_shape[0],
)

if expert_map is not None:
# DeepGemm (Grouped Contiguous) kernel needs a valid B index
# for all rows of A. To that effect, simply compute with
# the 0th weight matrix.
# Note that this relies on the fact that corresponding topk
# weights would be 0 during weight multiplication.
expert_ids = torch.where(expert_ids == -1, 0, expert_ids)

# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)

mm1_out = _resize_cache(workspace2, (M_sum, N))
act_out = _resize_cache(workspace13, (M_sum, N // 2))
quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
M_sum = compute_aligned_M(M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=deep_gemm_block_shape()[0],
expert_tokens_meta=expert_tokens_meta)

a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M_sum, K))
mm1_out = _resize_cache(workspace13, (M_sum, N))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2))
mm2_out = _resize_cache(workspace13, (M_sum, K))
perm_out = _resize_cache(workspace2, (M * num_topk, K))
mm2_out = _resize_cache(workspace2, (M_sum, K))

a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
expert_tokens_meta=expert_tokens_meta,
aq_out=a1q_perm)
assert a1q.size(0) == M_sum

m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
mm1_out, expert_ids)
Expand All @@ -183,14 +187,15 @@ def apply(
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids)

torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)

TopKWeightAndReduceContiguous().apply(
output=output,
fused_expert_output=perm_out,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)
deepgemm_unpermute_and_reduce(a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=output)
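
Since finalize_weight_and_reduce_impl() now returns TopKWeightAndReduceNoOP(), the reduction happens inside deepgemm_unpermute_and_reduce itself. Below is an unfused PyTorch sketch of what that fused Triton kernel is expected to compute, written for intuition rather than as the kernel's actual semantics (the inv_perm convention and expert_map masking follow the old _moe_permute / TopKWeightAndReduceContiguous path being replaced):

from typing import Optional

import torch


def unpermute_and_reduce_ref(a: torch.Tensor, topk_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             inv_perm: torch.Tensor,
                             expert_map: Optional[torch.Tensor],
                             output: torch.Tensor) -> None:
    M, num_topk = topk_ids.shape
    K = a.size(1)
    # Undo the expert-contiguous permutation: rows return to (token, topk)
    # order, matching the old torch.index_select(..., inv_perm) step.
    gathered = a.index_select(0, inv_perm).view(M, num_topk, K)
    weights = topk_weights.to(gathered.dtype)
    if expert_map is not None:
        # Rows routed to non-local experts contribute nothing to the output.
        weights = weights * (expert_map[topk_ids] != -1).to(weights.dtype)
    # Apply router weights and reduce over the topk dimension.
    output.copy_((gathered * weights.unsqueeze(-1)).sum(dim=1))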


def deep_gemm_moe_fp8(