
Commit 11dfdf2

varun-sundar-rabindranath and Varun Sundar Rabindranath authored
[Kernel] DeepGemm MoE : Integrate triton permute / unpermute kernels (#20903)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent fdc5b43 commit 11dfdf2

File tree

10 files changed: +491, −59 lines


tests/kernels/moe/modular_kernel_tools/cli_args.py

Lines changed: 0 additions & 1 deletion
@@ -85,7 +85,6 @@ def to_quant_torch_dtype(s: str) -> torch.dtype:
         help="num topk")
     parser.add_argument(
         "--fused-moe-chunk-size",
-        nargs="+",
         type=int,
         help="Fused moe chunk size used for the non-batched fused experts impl."
     )
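For illustration, a minimal argparse sketch (not the vLLM test harness itself) of what removing nargs="+" changes: the flag now parses to a single int rather than a list of ints.

```python
# Minimal illustrative sketch, not the actual vLLM CLI: with nargs="+" removed,
# --fused-moe-chunk-size yields one int instead of a list of ints.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fused-moe-chunk-size",
    type=int,
    help="Fused moe chunk size used for the non-batched fused experts impl.")

args = parser.parse_args(["--fused-moe-chunk-size", "1024"])
print(args.fused_moe_chunk_size)  # 1024 (an int), previously e.g. [1024]
```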

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         # FIXME (varun): We should be able to dispatch only from the leader

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 5 additions & 2 deletions
@@ -116,18 +116,21 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm:
             assert self.batched_deep_gemm_experts is not None
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_metadata)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_metadata)

     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
               w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
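A simplified, self-contained sketch of the wrapper pattern in this hunk: the new optional expert_tokens_metadata argument is threaded through workspace_shapes to whichever backend may be chosen, while allocation stays pessimistic for the larger DeepGemm workspaces. The names PerExpertTokens, DeepGemmLike, TritonLike and TritonOrDeepGemmLike, and the workspace formulas, are illustrative assumptions, not vLLM's actual classes.

```python
# Simplified sketch of the dispatch pattern above; all names and formulas are
# illustrative stand-ins, not vLLM's mk.ExpertTokensMetadata or real experts.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PerExpertTokens:
    # per-expert token counts known after dispatch
    expert_num_tokens: list[int]


class DeepGemmLike:
    def workspace_shapes(self, M: int, N: int, K: int, topk: int,
                         meta: Optional[PerExpertTokens]):
        # With metadata, size for the actual permuted row count (ignoring
        # block alignment for brevity); otherwise assume the worst case of
        # M * topk rows.
        rows = sum(meta.expert_num_tokens) if meta is not None else M * topk
        return (rows, max(N, K)), (rows, max(N // 2, K))


class TritonLike:
    def workspace_shapes(self, M, N, K, topk, meta):
        return (M * topk, N), (M * topk, K)


class TritonOrDeepGemmLike:
    """Forwards the optional metadata to whichever backend it may use."""

    def __init__(self, allow_deep_gemm: bool):
        self.allow_deep_gemm = allow_deep_gemm
        self.deep_gemm, self.triton = DeepGemmLike(), TritonLike()

    def workspace_shapes(self, M, N, K, topk,
                         meta: Optional[PerExpertTokens]):
        # Pessimistic choice, mirroring the comment in the diff: allocate for
        # the (larger) DeepGemm workspaces whenever DeepGemm might be used.
        impl = self.deep_gemm if self.allow_deep_gemm else self.triton
        return impl.workspace_shapes(M, N, K, topk, meta)


wrapper = TritonOrDeepGemmLike(allow_deep_gemm=True)
print(wrapper.workspace_shapes(8, 1024, 512, 2, PerExpertTokens([5, 3, 8, 0])))
```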

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 1 addition & 0 deletions
@@ -271,6 +271,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 53 additions & 48 deletions
@@ -8,16 +8,16 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
-    _moe_permute)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
+    compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm, round_up
+from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous

 logger = init_logger(__name__)
@@ -93,18 +93,25 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         return TopKWeightAndReduceNoOP()

     def workspace_shapes(
-        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
-        topk: int, global_num_experts: int, local_num_experts: int
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert self.block_shape is not None
-        # We use global_num_experts due to how moe_align_block_size handles
-        # expert_maps.
-        num_experts = global_num_experts
         block_m = self.block_shape[0]
-        M_sum = (M * topk) + num_experts * (block_m - 1)
-        M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N // 2, K))
-        workspace2 = (M_sum, max(N, K))
+        M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
+                                  expert_tokens_meta)
+        assert M_sum % block_m == 0
+
+        workspace1 = (M_sum, max(N, K))
+        workspace2 = (M_sum, max(N // 2, K))
         output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
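The removed inline formula sized M_sum pessimistically as (M * topk) + num_experts * (block_m - 1), rounded up to block_m; the new compute_aligned_M call can additionally use per-expert token counts from expert_tokens_meta. Below is a hedged sketch of that alignment arithmetic, inferred from the removed formula and the new call site; it is an assumption for illustration, not vLLM's compute_aligned_M.

```python
# Illustrative sketch only: an assumed reimplementation of the alignment
# arithmetic. The real compute_aligned_M lives in fused_moe/deep_gemm_utils.py.
from typing import Optional


def round_up(x: int, a: int) -> int:
    return ((x + a - 1) // a) * a


def aligned_m(M: int, num_topk: int, local_num_experts: int, alignment: int,
              expert_num_tokens: Optional[list[int]] = None) -> int:
    # Each expert's rows must start on an `alignment` (block_m) boundary for
    # the grouped contiguous GEMM, so every expert can waste up to
    # alignment - 1 padding rows.
    if expert_num_tokens is not None:
        # Exact per-expert counts are known: pad each expert's count up to the
        # alignment and sum.
        return sum(round_up(int(n), alignment) for n in expert_num_tokens)
    # Otherwise use the pessimistic bound from the pre-change code.
    return round_up(M * num_topk + local_num_experts * (alignment - 1),
                    alignment)


# 8 tokens, top-2 routing, 4 local experts, 128-row blocks:
print(aligned_m(8, 2, 4, 128))                # 640, always a multiple of 128
print(aligned_m(8, 2, 4, 128, [5, 3, 8, 0]))  # 384
```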

@@ -131,43 +138,40 @@ def apply(
         apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None
+        assert a1q_scale is not None

         a1q = hidden_states
         _, N, K = w1.size()
-        M, _ = output.size()
-        num_topk = topk_ids.size(1)

+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts

         assert w2.size(1) == K

-        a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
-            a1q,
-            a1q_scale,
-            topk_ids,
-            global_num_experts,
-            expert_map,
-            self.block_shape[0],
-        )
-
-        if expert_map is not None:
-            # DeepGemm (Grouped Contiguous) kernel needs a valid B index
-            # for all rows of A. To that effect, simply compute with
-            # the 0th weight matrix.
-            # Note that this relies on the fact that corresponding topk
-            # weights would be 0 during weight multiplication.
-            expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
-
-        # Note: M_sum is different than the pre-permuted shape of a1q.
-        M_sum = a1q.size(0)
-
-        mm1_out = _resize_cache(workspace2, (M_sum, N))
-        act_out = _resize_cache(workspace13, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+        M_sum = compute_aligned_M(M=topk_ids.size(0),
+                                  num_topk=topk_ids.size(1),
+                                  local_num_experts=local_num_experts,
+                                  alignment=deep_gemm_block_shape()[0],
+                                  expert_tokens_meta=expert_tokens_meta)
+
+        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+                                 (M_sum, K))
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace13, (M_sum, K))
-        perm_out = _resize_cache(workspace2, (M * num_topk, K))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+
+        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
+            aq=a1q,
+            aq_scale=a1q_scale,
+            topk_ids=topk_ids,
+            local_num_experts=local_num_experts,
+            expert_map=expert_map,
+            expert_tokens_meta=expert_tokens_meta,
+            aq_out=a1q_perm)
+        assert a1q.size(0) == M_sum

         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)

@@ -183,14 +187,15 @@ def apply(
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)

-        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+        if apply_router_weight_on_input:
+            topk_weights = torch.ones_like(topk_weights)

-        TopKWeightAndReduceContiguous().apply(
-            output=output,
-            fused_expert_output=perm_out,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+        deepgemm_unpermute_and_reduce(a=mm2_out,
+                                      topk_ids=topk_ids,
+                                      topk_weights=topk_weights,
+                                      inv_perm=inv_perm,
+                                      expert_map=expert_map,
+                                      output=output)


 def deep_gemm_moe_fp8(
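The rewritten apply() carves every intermediate (the permuted fp8 input, both grouped GEMM outputs, the activation output, and its fp8 quantization) out of the two preallocated workspaces by re-viewing them with new shapes and, for the fp8 stages, a new dtype. The sketch below illustrates only that aliasing idea; resize_view is a hypothetical stand-in for vLLM's _resize_cache, and the shapes simply mirror workspace_shapes above.

```python
# Sketch of the workspace-aliasing idea; resize_view is an illustrative
# stand-in for vLLM's _resize_cache, not the MoE kernel itself.
import torch


def resize_view(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Reinterpret the leading elements of a scratch buffer as `shape`.
    numel = 1
    for s in shape:
        numel *= s
    assert buf.numel() >= numel, "scratch buffer too small"
    return buf.flatten()[:numel].view(shape)


M_sum, N, K = 256, 1024, 512
# Sized as in workspace_shapes: (M_sum, max(N, K)) and (M_sum, max(N // 2, K)).
workspace13 = torch.empty(M_sum, max(N, K), dtype=torch.bfloat16)
workspace2 = torch.empty(M_sum, max(N // 2, K), dtype=torch.bfloat16)

# Permuted fp8 input aliases workspace2, viewed byte-wise as fp8.
a1q_perm = resize_view(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K))
# First grouped GEMM writes into workspace13 ...
mm1_out = resize_view(workspace13, (M_sum, N))
# ... the activation output reuses workspace2 (a1q_perm is no longer needed),
act_out = resize_view(workspace2, (M_sum, N // 2))
# ... its fp8 quantization reuses workspace13, and the second GEMM output
# reuses workspace2 before the final unpermute-and-reduce writes `output`.
quant_out = resize_view(workspace13.view(dtype=torch.float8_e4m3fn),
                        (M_sum, N // 2))
mm2_out = resize_view(workspace2, (M_sum, K))
print(a1q_perm.shape, mm1_out.shape, act_out.shape, quant_out.shape,
      mm2_out.shape)
```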
