diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py
index 261f1eb6e5c3..b95d87cd04f5 100644
--- a/tests/kernels/moe/modular_kernel_tools/cli_args.py
+++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py
@@ -85,7 +85,6 @@ def to_quant_torch_dtype(s: str) -> torch.dtype:
                         help="num topk")
     parser.add_argument(
         "--fused-moe-chunk-size",
-        nargs="+",
         type=int,
         help="Fused moe chunk size used for the non-batched fused experts impl."
     )
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 0b3943292152..e61d350388ea 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -239,6 +239,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         assert a.dim() == 2
         # FIXME (varun): We should be able to dispatch only from the leader
diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
index 12df9bb34d25..1a63b3237343 100644
--- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
@@ -116,6 +116,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -123,11 +124,13 @@
         if self.allow_deep_gemm:
             assert self.batched_deep_gemm_experts is not None
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_metadata)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_metadata)
 
     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
               w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index e479f1b40444..d09161ead464 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -271,6 +271,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index cc5e7cf57147..bb462938a392 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -8,16 +8,16 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
-    _moe_permute)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
+    compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm, round_up
+from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
 logger = init_logger(__name__)
@@ -93,18 +93,25 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         return TopKWeightAndReduceNoOP()
 
     def workspace_shapes(
-        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
-        topk: int, global_num_experts: int, local_num_experts: int
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         assert self.block_shape is not None
-        # We use global_num_experts due to how moe_align_block_size handles
-        # expert_maps.
-        num_experts = global_num_experts
         block_m = self.block_shape[0]
-        M_sum = (M * topk) + num_experts * (block_m - 1)
-        M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N // 2, K))
-        workspace2 = (M_sum, max(N, K))
+        M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
+                                  expert_tokens_meta)
+        assert M_sum % block_m == 0
+
+        workspace1 = (M_sum, max(N, K))
+        workspace2 = (M_sum, max(N // 2, K))
         output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
@@ -131,43 +138,40 @@ def apply(
         apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None
+        assert a1q_scale is not None
 
         a1q = hidden_states
         _, N, K = w1.size()
-        M, _ = output.size()
-        num_topk = topk_ids.size(1)
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts
 
         assert w2.size(1) == K
 
-        a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
-            a1q,
-            a1q_scale,
-            topk_ids,
-            global_num_experts,
-            expert_map,
-            self.block_shape[0],
-        )
-
-        if expert_map is not None:
-            # DeepGemm (Grouped Contiguous) kernel needs a valid B index
-            # for all rows of A. To that effect, simply compute with
-            # the 0th weight matrix.
-            # Note that this relies on the fact that corresponding topk
-            # weights would be 0 during weight multiplication.
-            expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
-
-        # Note: M_sum is different than the pre-permuted shape of a1q.
-        M_sum = a1q.size(0)
-
-        mm1_out = _resize_cache(workspace2, (M_sum, N))
-        act_out = _resize_cache(workspace13, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+        M_sum = compute_aligned_M(M=topk_ids.size(0),
+                                  num_topk=topk_ids.size(1),
+                                  local_num_experts=local_num_experts,
+                                  alignment=deep_gemm_block_shape()[0],
+                                  expert_tokens_meta=expert_tokens_meta)
+
+        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+                                 (M_sum, K))
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace13, (M_sum, K))
-        perm_out = _resize_cache(workspace2, (M * num_topk, K))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+
+        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
+            aq=a1q,
+            aq_scale=a1q_scale,
+            topk_ids=topk_ids,
+            local_num_experts=local_num_experts,
+            expert_map=expert_map,
+            expert_tokens_meta=expert_tokens_meta,
+            aq_out=a1q_perm)
+        assert a1q.size(0) == M_sum
 
         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
@@ -183,14 +187,15 @@
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)
 
-        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+        if apply_router_weight_on_input:
+            topk_weights = torch.ones_like(topk_weights)
 
-        TopKWeightAndReduceContiguous().apply(
-            output=output,
-            fused_expert_output=perm_out,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+        deepgemm_unpermute_and_reduce(a=mm2_out,
+                                      topk_ids=topk_ids,
+                                      topk_weights=topk_weights,
+                                      inv_perm=inv_perm,
+                                      expert_map=expert_map,
+                                      output=output)
 
 
 def deep_gemm_moe_fp8(
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
new file mode 100644
index 000000000000..8cc5a747c673
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
@@ -0,0 +1,413 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00ba5c0cef8bff6/lightllm/common/fused_moe/deepep_scatter_gather.py
+and updated to fit vLLM needs and terminology.
+"""
+
+import functools
+from typing import Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
+from vllm.triton_utils import tl, triton
+from vllm.utils import round_up
+
+
+@functools.cache
+def deep_gemm_block_shape() -> list[int]:
+    # Lazy import to avoid CUDA initialization problems.
+    import deep_gemm as dg
+    block = dg.get_m_alignment_for_contiguous_layout()
+    return [block, block]
+
+
+def expert_num_tokens_round_up_and_sum(expert_num_tokens: torch.Tensor,
+                                       alignment: int) -> int:
+    # Round up each element in expert_num_tokens to the nearest multiple of
+    # alignment.
+    ent = (expert_num_tokens.to(torch.int64) +
+           (alignment - 1)) // alignment * alignment
+    return torch.sum(ent).item()
+
+
+def compute_aligned_M(M: int, num_topk: int, local_num_experts: int,
+                      alignment: int,
+                      expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+
+    if ((expert_tokens_meta is not None)
+            and (expert_tokens_meta.expert_num_tokens_cpu is not None)):
+        return expert_num_tokens_round_up_and_sum(
+            expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment)
+
+    # expert_num_tokens information is not available on the CPU.
+    # Compute the max required size.
+    M_sum = (M * num_topk) + local_num_experts * (alignment - 1)
+    M_sum = round_up(M_sum, alignment)
+    return M_sum
+
+
+@triton.jit
+def apply_expert_map(expert_id, expert_map):
+    if expert_id != -1:
+        expert_id = tl.load(expert_map + expert_id).to(tl.int64)
+    return expert_id
+
+
+@triton.jit
+def round_up_128(x: int) -> int:
+    y = 128
+    return ((x + y - 1) // y) * y
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(num_recv_tokens_per_expert + offset_cumsum,
+                                mask=offset_cumsum < num_experts,
+                                other=0)
+    tokens_per_expert = round_up_128(tokens_per_expert)
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum,
+             cumsum,
+             mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    expert_map,
+    HAS_EXPERT_MAP: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in,
+                          mask=mask)
+        to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 +
+                            offset_in_s,
+                            mask=mask_s)
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 +
+                                topk_index)
+
+            if HAS_EXPERT_MAP:
+                expert_id = apply_expert_map(expert_id, expert_map)
+
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id,
+                                                 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 +
+                    topk_index, dest_token_index)
+                output_tensor_ptr = (output_tensor +
+                                     dest_token_index * output_tensor_stride0)
+                output_tensor_scale_ptr = (
+                    output_tensor_scale +
+                    dest_token_index * output_tensor_scale_stride0)
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s,
+                         to_copy_s,
+                         mask=mask_s)
+
+
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_map: Optional[torch.Tensor],
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # token count per expert is aligned to 128
+    BLOCK_D = 128  # block size of quantization
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid, )](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid, )](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        expert_map=expert_map,
+        HAS_EXPERT_MAP=expert_map is not None,
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    expert_map,
+    HAS_EXPERT_MAP: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(recv_topk_ids +
+                                cur_token * recv_topk_ids_stride0 + topk_index)
+
+            if HAS_EXPERT_MAP:
+                expert_id = apply_expert_map(expert_id, expert_map)
+
+            if expert_id >= 0:
+                source_token_index = tl.load(input_index +
+                                             cur_token * input_index_stride0 +
+                                             topk_index)
+                acc_weight = tl.load(recv_topk_weight +
+                                     cur_token * recv_topk_weight_stride0 +
+                                     topk_index)
+                tmp = tl.load(input_tensor +
+                              source_token_index * input_tensor_stride0 +
+                              cur_block * BLOCK_D + off_d)
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor + cur_token * output_tensor_stride0 +
+            cur_block * BLOCK_D + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    expert_map: Optional[torch.Tensor],
+    output_tensor: torch.Tensor,
+):
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    BLOCK_D = min(hidden_size, 1024)
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        expert_map=expert_map,
+        HAS_EXPERT_MAP=expert_map is not None,
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
+def deepgemm_moe_permute(aq: torch.Tensor,
+                         aq_scale: torch.Tensor,
+                         topk_ids: torch.Tensor,
+                         local_num_experts: int,
+                         expert_map: Optional[torch.Tensor],
+                         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+                         aq_out: Optional[torch.Tensor] = None):
+
+    assert aq.ndim == 2
+    assert topk_ids.dtype.is_signed, (
+        "The kernel uses -1 to represent invalid topk_ids")
+    H = aq.size(1)
+    device = aq.device
+
+    block_m = deep_gemm_block_shape()[0]
+    block_k = deep_gemm_block_shape()[1]
+
+    M_sum = compute_aligned_M(M=topk_ids.size(0),
+                              num_topk=topk_ids.size(1),
+                              local_num_experts=local_num_experts,
+                              alignment=block_m,
+                              expert_tokens_meta=expert_tokens_meta)
+
+    expert_start_loc = torch.empty((local_num_experts),
+                                   device=device,
+                                   dtype=torch.int32)
+
+    assert aq_out is None or aq_out.shape == (M_sum, H)
+    if aq_out is None:
+        aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype)
+
+    aq_scale_out = torch.empty((M_sum, H // block_k),
+                               device=device,
+                               dtype=torch.float32)
+
+    maybe_has_empty_blocks = ((expert_tokens_meta is None)
+                              or (expert_tokens_meta.expert_num_tokens_cpu
+                                  is None))
+    expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty
+
+    expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
+    inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
+
+    expert_num_tokens = None
+    if expert_tokens_meta is not None:
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
+    else:
+        expert_num_tokens = count_expert_num_tokens(topk_ids,
+                                                    local_num_experts,
+                                                    expert_map)
+
+    ep_scatter(recv_x=aq,
+               recv_x_scale=aq_scale,
+               recv_topk=topk_ids,
+               num_recv_tokens_per_expert=expert_num_tokens,
+               expert_start_loc=expert_start_loc,
+               expert_map=expert_map,
+               output_tensor=aq_out,
+               output_tensor_scale=aq_scale_out,
+               m_indices=expert_ids,
+               output_index=inv_perm)
+
+    return aq_out, aq_scale_out, expert_ids, inv_perm
+
+
+def deepgemm_unpermute_and_reduce(
+        a: torch.Tensor,  # Grouped gemm output
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        inv_perm: torch.Tensor,
+        expert_map: Optional[torch.Tensor],
+        output: torch.Tensor):
+
+    return ep_gather(input_tensor=a,
+                     recv_topk_ids=topk_ids,
+                     recv_topk_weight=topk_weights,
+                     input_index=inv_perm,
+                     expert_map=expert_map,
+                     output_tensor=output)
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index b311ef1ac1cb..ab8a281b3901 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -677,6 +677,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         assert a.dim() == 2
         num_dp = self.num_dispatchers
@@ -889,6 +890,7 @@
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         assert a.dim() == 2
         num_dp = self.num_dispatchers
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index f0bffc7dae27..1d6431c6cf12 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1618,6 +1618,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         workspace1 = (M, topk, max(N // 2, K))
         workspace2 = (M, topk, max(N, K))
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 028eee241786..bc4eb3b1932a 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -317,6 +317,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -479,7 +480,8 @@ def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
 
         (workspace13_shape, workspace2_shape, fused_out_shape,
          workspace_dtype) = self.fused_experts.workspace_shapes(
-             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
+             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
+             expert_tokens_meta)
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
@@ -572,10 +574,9 @@ def _maybe_chunk_fused_experts(
         assert num_chunks > 1
 
         # Construct the entire output that can then be processed in chunks.
-        (_, _, fused_out_shape,
-         _) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
-                                                  global_num_experts,
-                                                  local_num_experts)
+        (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
+            a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
+            expert_tokens_meta)
         fused_out = torch.empty(fused_out_shape,
                                 device=a1q.device,
                                 dtype=a1.dtype)
@@ -613,8 +614,11 @@ def slice_expert_tokens_metadata(
         need_expert_num_tokens_cpu = (
             full_expert_tokens_meta.expert_num_tokens_cpu is not None)
         if need_expert_num_tokens_cpu:
+            # This copy is blocking because some implementations need the
+            # count on the CPU to determine appropriate input/output
+            # fused-moe buffers.
             c_expert_num_tokens_cpu = c_expert_num_tokens.to(
-                "cpu", non_blocking=True)
+                "cpu", non_blocking=False)
 
         return ExpertTokensMetadata(
             expert_num_tokens=c_expert_num_tokens,
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index 2f35c19b7054..51b95c9aa922 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -102,6 +102,7 @@ def workspace_shapes(
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...],
                torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -110,11 +111,13 @@
                or is_blackwell_deep_gemm_used()):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts,
+                expert_tokens_meta)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
                                                        global_num_experts,
                                                        local_num_experts,
                                                        expert_tokens_meta)
 
     def apply(
         self,
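
Note on the sizing change: `compute_aligned_M` replaces the old pessimistic `M_sum = round_up(M * topk + num_experts * (block_m - 1), block_m)` computation in `workspace_shapes`. When the per-expert token counts are already on the CPU (`expert_tokens_meta.expert_num_tokens_cpu`), it sizes the workspaces exactly by rounding each expert's count up to the DeepGEMM block alignment and summing; otherwise it falls back to the old upper bound. A minimal standalone sketch of the two paths with a worked example (the `aligned_m` helper here is illustrative, not part of the PR):

```python
from typing import Optional

import torch


def round_up(x: int, a: int) -> int:
    return ((x + a - 1) // a) * a


def aligned_m(M: int, num_topk: int, local_num_experts: int, alignment: int,
              expert_num_tokens_cpu: Optional[torch.Tensor]) -> int:
    if expert_num_tokens_cpu is not None:
        # Exact path: round each expert's token count up to `alignment`,
        # then sum. Experts with zero tokens contribute zero rows.
        ent = (expert_num_tokens_cpu.to(torch.int64) +
               (alignment - 1)) // alignment * alignment
        return int(ent.sum().item())
    # Counts unknown on the CPU: pessimistic upper bound, since every
    # expert can waste at most (alignment - 1) rows of padding.
    return round_up(M * num_topk + local_num_experts * (alignment - 1),
                    alignment)


# block_m = 128, 4 tokens, topk = 2, counts per local expert [3, 5, 0, 0]:
counts = torch.tensor([3, 5, 0, 0])
assert aligned_m(4, 2, 4, 128, counts) == 256  # 128 + 128 + 0 + 0
assert aligned_m(4, 2, 4, 128, None) == 640    # round_up(8 + 4 * 127, 128)
```

The exact path is also why the switch from `global_num_experts` to `local_num_experts` is safe here: the permuted buffer only holds rows for experts resident on this rank, since the scatter kernel drops entries whose mapped expert id is -1.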
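Note on `ep_scatter`/`ep_gather`: together they implement the permute/unpermute contract that previously went through `_moe_permute` and `TopKWeightAndReduceContiguous`. `ep_scatter` packs each token's topk copies into a per-expert contiguous region (128-row aligned, with `m_indices` naming the owning expert of every row and `output_index` recording each copy's destination row), and `ep_gather` reads the rows back through `output_index` and does the weighted reduction. A pure-PyTorch reference of that contract, illustrative only: no Triton, no quantization scales, no `expert_map`, and sequential rather than atomic destination assignment:

```python
import torch

ALIGN = 128  # DeepGEMM m-alignment for the contiguous grouped layout


def ref_scatter(x: torch.Tensor, topk_ids: torch.Tensor, num_experts: int):
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    aligned = (counts + ALIGN - 1) // ALIGN * ALIGN
    starts = torch.cumsum(aligned, 0) - aligned  # exclusive cumsum
    M_sum = int(aligned.sum())

    out = x.new_zeros((M_sum, x.size(1)))
    m_indices = torch.empty((M_sum, ), dtype=torch.int32)
    inv_perm = torch.empty_like(topk_ids)
    for e in range(num_experts):
        s, n = int(starts[e]), int(aligned[e])
        m_indices[s:s + n] = e  # padding rows still carry the expert id
    fill = starts.clone()
    for t in range(topk_ids.size(0)):
        for k in range(topk_ids.size(1)):
            e = int(topk_ids[t, k])
            row = int(fill[e])
            fill[e] += 1
            out[row] = x[t]
            inv_perm[t, k] = row  # where this (token, expert) copy landed
    return out, m_indices, inv_perm


def ref_gather(y: torch.Tensor, topk_weights: torch.Tensor,
               inv_perm: torch.Tensor) -> torch.Tensor:
    # output[t] = sum_k topk_weights[t, k] * y[inv_perm[t, k]]
    gathered = y[inv_perm.view(-1).long()].view(*inv_perm.shape, y.size(1))
    return (gathered * topk_weights.unsqueeze(-1)).sum(dim=1)
```

Because `_fwd_kernel_ep_scatter_2` assigns destination rows with `tl.atomic_add`, the row order inside an expert's region is nondeterministic; `ep_gather` is insensitive to this since it reads back through `output_index` (called `inv_perm` at the call sites).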
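Note on the `non_blocking=True` to `False` change in `slice_expert_tokens_metadata`: the sliced per-expert counts are consumed on the host almost immediately (`compute_aligned_M` ultimately calls `.item()` on them to size the workspaces), and a non-blocking device-to-host copy into pageable memory is not guaranteed to have completed by the time the host reads it unless there is an explicit synchronization. A minimal illustration of the hazard and of the pattern the PR settles on (the tensor values are made up):

```python
import torch

dev_counts = torch.tensor([3, 5, 0, 0], device="cuda")

# Risky: a non-blocking device-to-host copy may still be in flight when the
# host reads the values, so the sum below could observe stale data unless a
# torch.cuda.synchronize() (or stream sync) runs first.
# cpu_counts = dev_counts.to("cpu", non_blocking=True)

# What the PR does: a blocking copy, so the host-side read that sizes the
# fused-moe input/output buffers sees the real counts.
cpu_counts = dev_counts.to("cpu", non_blocking=False)
m_sum = int(((cpu_counts + 127) // 128 * 128).sum().item())  # 256
```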