Commit 533d77f

Author: Varun Sundar Rabindranath (committed)

fixes

Signed-off-by: Varun Sundar Rabindranath <[email protected]>

1 parent d5f49e8 commit 533d77f

7 files changed (+22 additions, -17 deletions)


vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 4 additions & 2 deletions
@@ -124,11 +124,13 @@ def workspace_shapes(
         if self.allow_deep_gemm:
             assert self.batched_deep_gemm_experts is not None
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_metadata)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_metadata)

     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
               w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 0 additions & 4 deletions
@@ -141,8 +141,6 @@ def apply(

         a1q = hidden_states
         _, N, K = w1.size()
-        M, _ = output.size()
-        num_topk = topk_ids.size(1)

         local_num_experts = w1.size(0)
         if global_num_experts == -1:
@@ -155,7 +153,6 @@ def apply(
             local_num_experts=local_num_experts,
             alignment=deep_gemm_block_shape()[0],
             expert_tokens_meta=expert_tokens_meta)
-        assert M_sum >= M * num_topk

         a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
                                  (M_sum, K))
@@ -189,7 +186,6 @@ def apply(
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)

-        # TODO (varun) : We could probably reshape mm2_out and pass as output
         if apply_router_weight_on_input:
             topk_weights = torch.ones_like(topk_weights)
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py

Lines changed: 3 additions & 2 deletions
@@ -37,6 +37,7 @@ def round_up_128(x: int) -> int:
 def compute_aligned_M(M: int, num_topk: int, local_num_experts: int,
                       alignment: int,
                       expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+
     if ((expert_tokens_meta is not None)
             and (expert_tokens_meta.expert_num_tokens_cpu is not None)):
         return expert_num_tokens_round_up_and_sum(
@@ -336,7 +337,7 @@ def ep_gather(
 def deepgemm_moe_permute(aq: torch.Tensor,
                          aq_scale: torch.Tensor,
                          topk_ids: torch.Tensor,
-                         local_num_experts: torch.Tensor,
+                         local_num_experts: int,
                          expert_map: Optional[torch.Tensor],
                          expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
                          aq_out: Optional[torch.Tensor] = None):
@@ -378,7 +379,7 @@ def deepgemm_moe_permute(aq: torch.Tensor,

     expert_num_tokens = None
     if expert_tokens_meta is not None:
-        expert_num_tokens = expert_tokens_meta.expert_num_tokens_gpu
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
     else:
         expert_num_tokens = count_expert_num_tokens(topk_ids,
                                                     local_num_experts,
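For context on the first hunk above: compute_aligned_M can size the permuted activation buffer from the exact per-expert token counts whenever expert_tokens_meta carries them on the CPU, instead of falling back to a worst-case estimate. The helper below is a minimal sketch of that rounding, assuming expert_num_tokens_round_up_and_sum rounds each expert's count up to `alignment` and sums the results; it is illustrative only, not the code from this commit.

import torch

def _round_up(x: int, alignment: int) -> int:
    # Smallest multiple of `alignment` that is >= x.
    return ((x + alignment - 1) // alignment) * alignment

def aligned_token_sum(expert_num_tokens_cpu: torch.Tensor,
                      alignment: int) -> int:
    # Round every expert's token count up to the block alignment
    # (e.g. the DeepGEMM block size) and sum, giving the M needed for
    # an expert-contiguous, block-aligned activation layout.
    return sum(_round_up(int(n), alignment)
               for n in expert_num_tokens_cpu.tolist())

# Example: counts [3, 0, 130] with alignment 128 -> 128 + 0 + 256 = 384.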

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 1 addition & 0 deletions
@@ -891,6 +891,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.num_dispatchers
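With the signature above, workspace_shapes implementations now receive the optional ExpertTokensMetadata that the callers updated elsewhere in this commit pass through. The sketch below shows one way an implementation might use it; the free-function form and the shape math are assumptions for illustration, not the actual vLLM code.

import torch

def workspace_shapes_sketch(
        a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
        topk: int, global_num_experts: int, local_num_experts: int,
        expert_tokens_meta) -> tuple[tuple[int, ...], tuple[int, ...],
                                     tuple[int, ...], torch.dtype]:
    if (expert_tokens_meta is not None
            and expert_tokens_meta.expert_num_tokens_cpu is not None):
        # Exact sizing: total tokens actually routed to this rank's experts.
        M_sum = int(expert_tokens_meta.expert_num_tokens_cpu.sum().item())
    else:
        # Worst case: every token contributes topk entries.
        M_sum = M * topk
    workspace13 = (M_sum, max(N, K))      # illustrative shapes only
    workspace2 = (M_sum, max(N // 2, K))  # illustrative shapes only
    fused_out = (M, topk, K)
    return workspace13, workspace2, fused_out, a.dtype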

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 9 additions & 6 deletions
@@ -480,7 +480,8 @@ def _do_fused_experts(self, fused_out: Optional[torch.Tensor],

         (workspace13_shape, workspace2_shape, fused_out_shape,
          workspace_dtype) = self.fused_experts.workspace_shapes(
-             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
+             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
+             expert_tokens_meta)

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
@@ -573,10 +574,9 @@ def _maybe_chunk_fused_experts(
         assert num_chunks > 1

         # Construct the entire output that can then be processed in chunks.
-        (_, _, fused_out_shape,
-         _) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
-                                                  global_num_experts,
-                                                  local_num_experts)
+        (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
+            a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
+            expert_tokens_meta)
         fused_out = torch.empty(fused_out_shape,
                                 device=a1q.device,
                                 dtype=a1.dtype)
@@ -614,8 +614,11 @@ def slice_expert_tokens_metadata(
         need_expert_num_tokens_cpu = (
             full_expert_tokens_meta.expert_num_tokens_cpu is not None)
         if need_expert_num_tokens_cpu:
+            # This is blocking as some implementations need the count
+            # on the CPU to determine appropriate input/out fused-moe
+            # buffers
             c_expert_num_tokens_cpu = c_expert_num_tokens.to(
-                "cpu", non_blocking=True)
+                "cpu", non_blocking=False)

         return ExpertTokensMetadata(
             expert_num_tokens=c_expert_num_tokens,
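On the last hunk above: a device-to-host copy issued with non_blocking=True may return before the data has actually landed on the CPU, so it requires an explicit synchronization before the values are read; code that immediately uses the per-expert counts to size fused-MoE input/output buffers would otherwise risk reading stale values. A minimal sketch of the blocking pattern the change adopts (the surrounding usage is illustrative, not from this commit):

import torch

def expert_counts_to_cpu(expert_num_tokens: torch.Tensor) -> torch.Tensor:
    # Blocking copy: only returns once the device-to-host transfer has
    # completed, so the counts can be read on the CPU right away.
    return expert_num_tokens.to("cpu", non_blocking=False)

# Illustrative usage: size a per-expert workspace from the synced counts.
# counts_cpu = expert_counts_to_cpu(expert_num_tokens_gpu)
# max_m = int(counts_cpu.max().item())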

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def prepare(
         # topk_indices_dtype() int32
         #
         if expert_map is not None:
-            logger.warn_once(
+            logger.warning_once(
                 "The PPLX backend does not support expert mapping. "
                 "The provided `expert_map` will be ignored.")
             expert_map = None  #noqa: F841

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 4 additions & 2 deletions
@@ -111,11 +111,13 @@ def workspace_shapes(
                 or is_blackwell_deep_gemm_used()):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_meta)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
                                                        global_num_experts,
-                                                       local_num_experts)
+                                                       local_num_experts,
+                                                       expert_tokens_meta)

     def apply(
         self,
