diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py deleted file mode 100644 index 4ed690090144..000000000000 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ /dev/null @@ -1,418 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -from typing import Any, TypedDict - -import ray -import torch -from transformers import AutoConfig - -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _moe_permute, - _moe_unpermute_and_reduce, -) -from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize -from vllm.platforms import current_platform -from vllm.utils import FlexibleArgumentParser - -FP8_DTYPE = current_platform.fp8_dtype() - - -class BenchmarkConfig(TypedDict): - BLOCK_SIZE_M: int - BLOCK_SIZE_N: int - BLOCK_SIZE_K: int - GROUP_SIZE_M: int - num_warps: int - num_stages: int - - -def benchmark_permute( - num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - use_customized_permute: bool = False, -) -> float: - # init_dtype = torch.float16 if use_fp8_w8a8 else dtype - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) - # output_hidden_states = torch.empty_like(hidden_states) - if use_fp8_w8a8: - align_block_size = 128 # deepgemm needs 128 m aligned block - qhidden_states, scale = _fp8_quantize(hidden_states, None, None) - else: - align_block_size = None - qhidden_states = hidden_states - - gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) - - input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) - topk_weights, topk_ids, token_expert_indices = fused_topk( - qhidden_states, input_gating, topk, False - ) - - def prepare(i: int): - input_gating.copy_(gating_output[i]) - - def run(): - if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) - ) - else: - ( - permuted_hidden_states, - a1q_scale, - sorted_token_ids, - expert_ids, - inv_perm, - ) = _moe_permute( - qhidden_states, None, topk_ids, num_experts, None, align_block_size - ) - - # JIT compilation & warmup - run() - torch.cuda.synchronize() - - # Capture 10 invocations with CUDA graph - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - for _ in range(10): - run() - torch.cuda.synchronize() - - # Warmup - for _ in range(5): - graph.replay() - torch.cuda.synchronize() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - latencies: list[float] = [] - for i in range(num_iters): - prepare(i) - torch.cuda.synchronize() - - start_event.record() - graph.replay() - end_event.record() - end_event.synchronize() - latencies.append(start_event.elapsed_time(end_event)) - avg = sum(latencies) / (num_iters * 10) * 1000 # us - graph.reset() - return avg - - -def benchmark_unpermute( - num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: 
bool, - use_int8_w8a16: bool, - num_iters: int = 100, - use_customized_permute: bool = False, -) -> float: - # init_dtype = torch.float16 if use_fp8_w8a8 else dtype - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) - output_hidden_states = torch.empty_like(hidden_states) - if use_fp8_w8a8: - align_block_size = 128 # deepgemm needs 128 m aligned block - qhidden_states, scale = _fp8_quantize(hidden_states, None, None) - else: - align_block_size = None - qhidden_states = hidden_states - - input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) - - topk_weights, topk_ids, token_expert_indices = fused_topk( - qhidden_states, input_gating, topk, False - ) - - def prepare(): - if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) - ) - # convert to fp16/bf16 as gemm output - return ( - permuted_hidden_states.to(dtype), - first_token_off, - inv_perm_idx, - m_indices, - ) - else: - ( - permuted_qhidden_states, - a1q_scale, - sorted_token_ids, - expert_ids, - inv_perm, - ) = _moe_permute( - qhidden_states, None, topk_ids, num_experts, None, align_block_size - ) - # convert to fp16/bf16 as gemm output - return ( - permuted_qhidden_states.to(dtype), - a1q_scale, - sorted_token_ids, - expert_ids, - inv_perm, - ) - - def run(input: tuple): - if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input - moe_unpermute( - permuted_hidden_states, - topk_weights, - topk_ids, - inv_perm_idx, - first_token_off, - topk, - num_experts, - num_experts, - ) - else: - ( - permuted_hidden_states, - a1q_scale, - sorted_token_ids, - expert_ids, - inv_perm, - ) = input - _moe_unpermute_and_reduce( - output_hidden_states, permuted_hidden_states, inv_perm, topk_weights - ) - - # JIT compilation & warmup - input = prepare() - run(input) - torch.cuda.synchronize() - - # Capture 10 invocations with CUDA graph - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - for _ in range(10): - run(input) - torch.cuda.synchronize() - - # Warmup - for _ in range(5): - graph.replay() - torch.cuda.synchronize() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - latencies: list[float] = [] - for i in range(num_iters): - torch.cuda.synchronize() - start_event.record() - graph.replay() - end_event.record() - end_event.synchronize() - latencies.append(start_event.elapsed_time(end_event)) - avg = sum(latencies) / (num_iters * 10) * 1000 # us - graph.reset() - return avg - - -@ray.remote(num_gpus=1) -class BenchmarkWorker: - def __init__(self, seed: int) -> None: - torch.set_default_device("cuda") - current_platform.seed_everything(seed) - self.seed = seed - # Get the device ID to allocate tensors and kernels - # on the respective GPU. This is required for Ray to work - # correctly with multi-GPU tuning on the ROCm platform. 
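Both benchmark helpers removed above follow the same measurement recipe: run once to trigger JIT compilation, capture ten invocations inside a CUDA graph, warm up the replay path, then time a number of replays with CUDA events and divide by the ten captured calls to get a per-invocation latency in microseconds. A minimal sketch of that pattern, factored into a reusable helper; the helper name and the argument-free run callable are illustrative, not part of the deleted file:

import torch

def time_with_cuda_graph(run, num_iters: int = 100, calls_per_graph: int = 10) -> float:
    """Average latency of run() in microseconds, measured via CUDA graph replay."""
    run()                          # JIT compilation & warmup
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # capture calls_per_graph invocations
        for _ in range(calls_per_graph):
            run()
    torch.cuda.synchronize()

    for _ in range(5):             # warm up graph replay
        graph.replay()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    latencies = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start.record()
        graph.replay()
        end.record()
        end.synchronize()
        latencies.append(start.elapsed_time(end))  # milliseconds per replay
    graph.reset()
    return sum(latencies) / (num_iters * calls_per_graph) * 1000  # microseconds per call

benchmark_permute additionally calls prepare(i) before every timed replay to refresh the gating input without re-capturing the graph.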
- self.device_id = int(ray.get_gpu_ids()[0]) - - def benchmark( - self, - num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_customized_permute: bool = False, - ) -> tuple[dict[str, int], float]: - current_platform.seed_everything(self.seed) - - permute_time = benchmark_permute( - num_tokens, - num_experts, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=100, - use_customized_permute=use_customized_permute, - ) - unpermute_time = benchmark_unpermute( - num_tokens, - num_experts, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=100, - use_customized_permute=use_customized_permute, - ) - return permute_time, unpermute_time - - -def get_weight_block_size_safety(config, default_value=None): - quantization_config = getattr(config, "quantization_config", {}) - if isinstance(quantization_config, dict): - return quantization_config.get("weight_block_size", default_value) - return default_value - - -def main(args: argparse.Namespace): - print(args) - - config = AutoConfig.from_pretrained( - args.model, trust_remote_code=args.trust_remote_code - ) - if config.architectures[0] == "DbrxForCausalLM": - E = config.ffn_config.moe_num_experts - topk = config.ffn_config.moe_top_k - elif config.architectures[0] == "JambaForCausalLM": - E = config.num_experts - topk = config.num_experts_per_tok - elif ( - config.architectures[0] == "DeepseekV3ForCausalLM" - or config.architectures[0] == "DeepseekV2ForCausalLM" - or config.architectures[0] == "Glm4MoeForCausalLM" - ): - E = config.n_routed_experts - topk = config.num_experts_per_tok - elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: - E = config.num_experts - topk = config.num_experts_per_tok - - else: - # Support for llama4 - config = config.get_text_config() - # Default: Mixtral. 
- E = config.num_local_experts - topk = config.num_experts_per_tok - - hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype - use_fp8_w8a8 = args.dtype == "fp8_w8a8" - use_int8_w8a16 = args.dtype == "int8_w8a16" - use_customized_permute = args.use_customized_permute - - if args.batch_size is None: - batch_sizes = [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, - ] - else: - batch_sizes = [args.batch_size] - - ray.init() - num_gpus = int(ray.available_resources()["GPU"]) - workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] - - def _distribute(method: str, inputs: list[Any]) -> list[Any]: - outputs = [] - worker_idx = 0 - for input_args in inputs: - worker = workers[worker_idx] - worker_method = getattr(worker, method) - output = worker_method.remote(*input_args) - outputs.append(output) - worker_idx = (worker_idx + 1) % num_gpus - return ray.get(outputs) - - outputs = _distribute( - "benchmark", - [ - ( - batch_size, - E, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - use_customized_permute, - ) - for batch_size in batch_sizes - ], - ) - - for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): - print(f"Batch size: {batch_size}") - print(f"Permute time: {permute:.2f} us") - print(f"Unpermute time: {unpermute:.2f} us") - - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument( - "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" - ) - parser.add_argument( - "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" - ) - parser.add_argument("--use-customized-permute", action="store_true") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--batch-size", type=int, required=False) - parser.add_argument("--trust-remote-code", action="store_true") - args = parser.parse_args() - - main(args) diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 661730c96867..c58756d98c0d 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -24,8 +24,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, int64_t BLOCK_SIZE_K, int64_t bit); #endif -bool moe_permute_unpermute_supported(); - void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor); \ No newline at end of file diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 13aecd8007a4..6edbcf5800f1 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -5,131 +5,6 @@ #include "permute_unpermute_kernels/dispatch.h" #include "core/registration.h" -// moe_permute kernels require at least CUDA 12.0 -#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) - -void moe_permute( - const torch::Tensor& input, // [n_token, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - torch::Tensor& topk_ids, // [n_token, topk] - const torch::Tensor& token_expert_indices, // [n_token, topk] - const std::optional& expert_map, // [n_expert] - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& - permuted_input, // [topk * n_token/align_block_size_m, hidden] - torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] - torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] - torch::Tensor& m_indices) { // [align_expand_m] - 
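The main() driver removed above shards the batch-size sweep over one Ray actor per visible GPU, handing out argument tuples round-robin and collecting all results with a single ray.get. A minimal sketch of that dispatch pattern with a trivial stand-in task; the Worker actor here is illustrative, not the deleted BenchmarkWorker:

import ray

@ray.remote(num_gpus=1)
class Worker:
    def benchmark(self, batch_size: int) -> float:
        # Stand-in for the real permute/unpermute measurement.
        return float(batch_size)

ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [Worker.remote() for _ in range(num_gpus)]

def distribute(method: str, inputs: list) -> list:
    # Round-robin the argument tuples over the worker pool, then block once
    # on all futures, mirroring the deleted _distribute helper.
    futures = []
    for i, args in enumerate(inputs):
        futures.append(getattr(workers[i % num_gpus], method).remote(*args))
    return ray.get(futures)

print(distribute("benchmark", [(bs,) for bs in (1, 16, 256)]))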
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, - "topk_weights must be float32"); - TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, - "expert_first_token_offset must be int64"); - TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, - "topk_ids must be int32"); - TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, - "token_expert_indices must be int32"); - TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, - "src_row_id2dst_row_id_map must be int32"); - TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, - "expert_first_token_offset shape != n_local_expert+1") - TORCH_CHECK( - src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), - "token_expert_indices shape must be same as src_row_id2dst_row_id_map"); - auto n_token = input.sizes()[0]; - auto n_hidden = input.sizes()[1]; - auto align_block_size_value = - align_block_size.has_value() ? align_block_size.value() : -1; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const long sorter_size = - CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); - auto sort_workspace = torch::empty( - {sorter_size}, - torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); - auto permuted_experts_id = torch::empty_like(topk_ids); - auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); - - CubKeyValueSorter sorter{}; - int64_t* valid_num_ptr = nullptr; - // pre-process kernel for expert-parallelism: - // no local expert id plus "n_expert" offset for priority to local expert - // map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1] - // For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id - // [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids - // and map global expert id [2, 3] to local_expert id [0, 1] and map global - // expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map - // operation is to make local expert high priority in following sort topk_ids - // and scan local expert_first_token_offset for each ep rank for next group - // gemm. 
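The comment above describes how preprocessTopkIdLauncher rewrites global expert ids before the sort when an expert_map is supplied: ids owned by the current EP rank are replaced by their local index, and ids owned by other ranks are shifted past n_expert so a stable sort places every local expert first. A hedged torch equivalent of that remapping (the deleted reference test performs the same transformation with start_expert arithmetic):

import torch

def preprocess_topk_ids(topk_ids: torch.Tensor,
                        expert_map: torch.Tensor,
                        n_expert: int) -> torch.Tensor:
    # expert_map[g] is the local id of global expert g, or -1 if the expert
    # lives on another EP rank.
    local_id = expert_map[topk_ids]
    is_local = local_id != -1
    # Local experts keep their local index; remote experts are pushed past
    # n_expert so they sort after every local expert.
    return torch.where(is_local, local_id, topk_ids + n_expert)

# Example from the kernel comment: 4 experts, ep_size=2, ep_rank=1 owns
# global experts [2, 3], so expert_map = [-1, -1, 0, 1].
expert_map = torch.tensor([-1, -1, 0, 1])
topk_ids = torch.tensor([[2, 0], [3, 1]])
print(preprocess_topk_ids(topk_ids, expert_map, n_expert=4))
# tensor([[0, 4], [1, 5]])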
- if (expert_map.has_value()) { - const int* expert_map_ptr = get_ptr(expert_map.value()); - valid_num_ptr = - get_ptr(expert_first_token_offset) + n_local_expert; - preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, - expert_map_ptr, n_expert, stream); - } - // expert sort topk expert id and scan expert id get expert_first_token_offset - sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indices), - get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(expert_first_token_offset), n_token, - n_expert, n_local_expert, topk, sorter, - get_ptr(sort_workspace), stream); - - // dispatch expandInputRowsKernelLauncher - MOE_DISPATCH(input.scalar_type(), [&] { - expandInputRowsKernelLauncher( - get_ptr(input), get_ptr(permuted_input), - get_ptr(topk_weights), get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(src_row_id2dst_row_id_map), - get_ptr(expert_first_token_offset), n_token, valid_num_ptr, - n_hidden, topk, n_local_expert, align_block_size_value, stream); - }); - - // get m_indices and update expert_first_token_offset with align block - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); - if (align_block_size.has_value()) { - // update align_expert_first_token_offset - expert_first_token_offset.copy_(align_expert_first_token_offset); - } -} - -void moe_unpermute( - const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - const torch::Tensor& topk_ids, // [n_token, topk] - const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] - const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] - int64_t n_expert, int64_t n_local_expert, int64_t topk, - torch::Tensor& hidden_states // [n_token, hidden] -) { - TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), - "topk_ids shape must be same as src_row_id2dst_row_id_map"); - TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, - "topk_ids must be int32"); - TORCH_CHECK( - permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), - "topk_ids dtype must be same as src_row_id2dst_row_id_map"); - auto n_token = hidden_states.size(0); - auto n_hidden = hidden_states.size(1); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int64_t* valid_ptr = - get_ptr(expert_first_token_offset) + n_local_expert; - MOE_DISPATCH(hidden_states.scalar_type(), [&] { - finalizeMoeRoutingKernelLauncher( - get_ptr(permuted_hidden_states), - get_ptr(hidden_states), get_ptr(topk_weights), - get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), - n_token, n_hidden, topk, valid_ptr, stream); - }); -} - template __global__ void shuffleInputRowsKernel(const T* input, const int32_t* dst2src_map, T* output, @@ -216,46 +91,3 @@ void shuffle_rows(const torch::Tensor& input_tensor, }); } } - -#else - -void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, - torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); -} - -void moe_unpermute(const torch::Tensor& input, - const 
torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); -} - -#endif - -bool moe_permute_unpermute_supported() { -#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) - return true; -#else - return false; -#endif -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("moe_permute", &moe_permute); - m.impl("moe_unpermute", &moe_unpermute); -} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 97df311d0440..5adaad0d5411 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -55,23 +55,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); - m.def( - "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," - "Tensor token_expert_indices, Tensor? expert_map, int n_expert," - "int n_local_expert," - "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " - "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " - "m_indices)->()"); - - m.def( - "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," - "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " - "expert_first_token_offset, int n_expert, int n_local_expert,int " - "topk, Tensor! hidden_states)->()"); - - m.def("moe_permute_unpermute_supported() -> bool"); - m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); - // Row shuffle for MoE m.def( "shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! " diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py deleted file mode 100644 index 7cc83b512c8b..000000000000 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ /dev/null @@ -1,226 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for the MOE permute/unpermute kernel - -Run `pytest tests/kernels/test_moe_permute_unpermute.py`. 
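The removed fallback stubs and the moe_permute_unpermute_supported() binding let Python code probe at runtime whether the custom kernels were compiled in (CUDA >= 12.0) before dispatching to them; the deleted test below skips itself the same way, and the deleted benchmark switches between the custom path and the sort-based _moe_permute helper with a flag. A hedged sketch of that guard against the pre-removal API; the two branches return different tuples, and the module paths are those of the files deleted in this diff:

import torch

def permute_for_grouped_gemm(hidden_states, topk_weights, topk_ids,
                             token_expert_indices, topk, n_expert,
                             use_customized_permute: bool):
    from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
        _moe_permute, moe_permute)

    if use_customized_permute and torch.ops._moe_C.moe_permute_unpermute_supported():
        # Custom CUDA >= 12.0 kernels removed by this PR.
        return moe_permute(hidden_states, topk_weights, topk_ids,
                           token_expert_indices, topk, n_expert, n_expert,
                           expert_map=None, align_block_size=128)
    # Fallback: the align-and-sort helper that also backed the DeepGEMM path.
    return _moe_permute(hidden_states, None, topk_ids, n_expert, None, 128)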
-""" - -from typing import Optional - -import numpy as np -import pytest -import torch - -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.layer import determine_expert_map -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_permute_unpermute_supported, moe_unpermute) -from vllm.platforms import current_platform - -NUM_EXPERTS = [16, 64] -TOP_KS = [2, 4, 6, 8] -EP_SIZE = [1, 4, 16] -current_platform.seed_everything(0) - - -def torch_permute(hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1) -> list[torch.Tensor]: - n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] - if expert_map is not None: - is_local_expert = (expert_map[topk_ids] != -1) - not_local_expert = (expert_map[topk_ids] == -1) - topk_ids = is_local_expert * ( - topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) - - sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), - stable=True) - dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] - - expert_first_token_offset = torch.zeros(n_local_expert + 1, - dtype=torch.int64, - device="cuda") - idx = 0 - for i in range(0, n_local_expert): - cnt = 0 - while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i: - cnt += 1 - idx += 1 - expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt - - _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) - valid_row_idx = [] - if align_block_size is None: - - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % - n_token, ...] 
- permuted_row_size = permuted_hidden_states.shape[0] - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) - for i in range(1, n_local_expert + 1): - first_token_offset = expert_first_token_offset[i - 1] - last_token_offset = expert_first_token_offset[i] - m_indices[first_token_offset:last_token_offset] = i - 1 - src_row_id2dst_row_id_map = torch.arange( - 0, n_token * topk, device="cuda", - dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) - valid_row_idx += [i for i in range(expert_first_token_offset[-1])] - return [ - permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices, valid_row_idx - ] - else: - permuted_row_size = (topk * n_token + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size - permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), - device="cuda", - dtype=hidden_states.dtype) - align_src_row_id2dst_row_id = torch.empty(n_token * topk, - device="cuda", - dtype=torch.int32) - align_expert_first_token_offset = torch.zeros_like( - expert_first_token_offset) - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) - # get align_permuted_hidden_states, - # valid row_idx and align_expert_first_token_offset - for i in range(1, n_local_expert + 1): - first_token_offset = expert_first_token_offset[i - 1] - last_token_offset = expert_first_token_offset[i] - n_token_in_expert = last_token_offset - first_token_offset - align_expert_first_token_offset[ - i] = align_expert_first_token_offset[ - i - 1] + (n_token_in_expert + align_block_size - - 1) // align_block_size * align_block_size - align_first_token_offset = align_expert_first_token_offset[i - 1] - align_last_token_offset = align_expert_first_token_offset[i] - dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + - n_token_in_expert] % n_token - # store token in current expert with align_first_token_offset - permuted_hidden_states[align_first_token_offset:\ - align_first_token_offset+n_token_in_expert,\ - ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert, ...] 
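When align_block_size is set, the reference implementation above rounds every expert's token count up to a multiple of the block size so that each expert's rows start on a block boundary for the subsequent grouped GEMM, which is why permuted_row_size reserves up to align_block_size - 1 extra rows per expert. A small numeric illustration of that offset scan; the token counts are chosen for the example only:

def aligned_offsets(tokens_per_expert: list[int], block: int) -> list[int]:
    # Round each expert's row count up to a multiple of `block`, then
    # prefix-sum to get the aligned first-token offsets.
    offsets = [0]
    for cnt in tokens_per_expert:
        offsets.append(offsets[-1] + (cnt + block - 1) // block * block)
    return offsets

# Three local experts receiving 5, 0 and 130 routed tokens, block size 128:
print(aligned_offsets([5, 0, 130], 128))  # [0, 128, 128, 384]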
- # set current expert m_indices - m_indices[align_first_token_offset:align_last_token_offset] = i - 1 - valid_row_idx += [ - i for i in range(align_first_token_offset, - align_first_token_offset + n_token_in_expert) - ] - # get align_src_row_id2dst_row_id - for i in range(n_token * topk): - eid = sorted_topk_ids[i] - if (eid >= n_local_expert): - # check token not in local expert - align_src_row_id2dst_row_id[ - i] = align_expert_first_token_offset[-1] - continue - first_token_offset = expert_first_token_offset[eid] - align_first_token_offset = align_expert_first_token_offset[eid] - token_offset = i - first_token_offset - align_src_row_id2dst_row_id[ - i] = align_first_token_offset + token_offset - align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ - src2dst_idx].reshape((n_token, topk)) - return [ - permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, m_indices, valid_row_idx - ] - - -def torch_unpermute(permuted_hidden_states: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - valid_row_idx: torch.Tensor, topk: int, - n_expert: int) -> torch.Tensor: - # ignore invalid row - mask = torch.zeros(permuted_hidden_states.shape[0], - dtype=bool, - device="cuda") - mask[valid_row_idx] = True - permuted_hidden_states[~mask] = 0 - idx = src_row_id2dst_row_id_map.flatten()[ - token_expert_indices.flatten()].reshape(token_expert_indices.shape) - output = permuted_hidden_states[idx, ...] * topk_weights[..., None] - output = output.sum(dim=1).to(permuted_hidden_states.dtype) - return output - - -@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000]) -@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168]) -@pytest.mark.parametrize("n_expert", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("align_block_size", [None, 128]) -def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, - n_expert: int, ep_size: int, dtype: torch.dtype, - align_block_size: Optional[int]): - if not moe_permute_unpermute_supported(): - pytest.skip("moe_permute_unpermute is not supported on this platform.") - fill_invalid_expert = 0 - ep_rank = np.random.randint(0, ep_size) - expert_map = None - n_local_expert = n_expert - if (ep_size != 1): - n_local_expert, expert_map = determine_expert_map( - ep_size, ep_rank, n_expert) - expert_map = expert_map.cuda() - start_expert = n_local_expert * ep_rank - current_platform.seed_everything(0) - hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) - gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, False) - gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( - hidden_states, - topk_ids, - token_expert_indices, - topk, - n_expert, - n_local_expert, - start_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) - - result0, result1, result2, result3 = moe_permute( - hidden_states, topk_weights, topk_ids, token_expert_indices, topk, - n_expert, n_local_expert, expert_map, align_block_size, - fill_invalid_expert) - - # check expert_first_token_offset - torch.testing.assert_close(gold1, result1, atol=0, rtol=0) - # check src_row_id2dst_row_id_map - 
torch.testing.assert_close(gold2, result2, atol=0, rtol=0) - # check mindice - torch.testing.assert_close(gold3, result3, atol=0, rtol=0) - # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold0[valid_row_idx], - result0[valid_row_idx], - atol=0, - rtol=0) - - # add a random tensor to simulate group gemm - result0 = 0.5 * result0 + torch.randn_like(result0) - - result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, - topk, n_expert, n_local_expert) - gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, result2, valid_row_idx, topk, - n_local_expert) - - # check unpermuted hidden - torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py deleted file mode 100644 index 20ee0d9f780a..000000000000 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ /dev/null @@ -1,190 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -import torch - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.fused_moe.utils import _fp8_perm - - -def _moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - block_m: int, -) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]: - """ - Determine the sorted_token_ids, expert_ids for the given problem size. - Permute the hidden states and scales according to `sorted_token_ids`. - """ - top_k_num = curr_topk_ids.size(1) - - tokens_in_chunk = curr_hidden_states.size(0) - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) - - inv_perm: Optional[torch.Tensor] = None - - num_tokens = top_k_num * tokens_in_chunk - expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] - - # Permute according to sorted token ids. - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) - - if a1q_scale is not None: - a1q_scale = a1q_scale[sorted_token_ids // top_k_num] - - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) - - -def _moe_unpermute_and_reduce( - out: torch.Tensor, - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk_weight: torch.Tensor, - apply_router_weight_on_input: bool, -) -> None: - """ - Unpermute the final result and apply topk_weights, then perform the final - reduction on the hidden states. - """ - M, topk = topk_weight.size() - K = curr_hidden.size(-1) - if inv_perm is not None: - curr_hidden = curr_hidden[inv_perm, ...] 
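_moe_permute above sorts token copies into expert order with moe_align_block_size and keeps inv_perm = argsort(sorted_token_ids) so that _moe_unpermute_and_reduce can gather the grouped-GEMM output back into token order before applying the routing weights and summing over top-k. The identity it relies on, in miniature (toy permutation chosen for illustration):

import torch

perm = torch.tensor([2, 0, 3, 1])   # row i of the permuted tensor comes from row perm[i]
inv_perm = torch.argsort(perm)      # argsort of a permutation is its inverse

x = torch.arange(8, dtype=torch.float32).reshape(4, 2)
permuted = x[perm]                  # expert-sorted layout
restored = permuted[inv_perm]       # back to the original row order
assert torch.equal(restored, x)

In the real path the gather index is sorted_token_ids // top_k and inv_perm is truncated to the first num_tokens entries to drop padding rows, but the inverse-permutation relationship is the same.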
- curr_hidden = curr_hidden.view(-1, topk, K) - if not apply_router_weight_on_input: - curr_hidden.mul_(topk_weight.view(M, -1, 1)) - ops.moe_sum(curr_hidden, out) - - -def moe_permute( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - This function expands and permutes activation to gather uncontinuous tokens - for each expert. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - topk_weights (torch.Tensor): topk expert route weight for each token. - - topk_ids (torch.Tensor): topk expert route id for each token. - - token_expert_indices (torch.Tensor): indice for expanded hidden. - - topk (int): The number of top-k experts to select. - - n_expert (int): The number of expert. - - n_local_expert (int): The number of expert in current EP rank. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - align_block_size (Optional[int]): align group gemm block size for deepgemm - - fill_invalid_expert(int): fill expert id in m_indices for invalid expert - to workaround DeepGemm unsupported -1 in m_indices - Returns: - - permuted_hidden_states (torch.Tensor): permuted activation. - - expert_first_token_offset (torch.Tensor): offset of the first token - of each expert for standard grouped gemm. if enable 'align_block_size' - expert_first_token_offset will align up to 'align_block_size'. - - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. 
- - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records - the group which the j-th row of the LHS belong to.` - """ - n_token, n_hidden = hidden_states.size() - assert (n_hidden * hidden_states.element_size() - ) % 16 == 0, "permue kernel need hidden dim align to 16B" - permuted_row_size = n_token * topk - if align_block_size is not None: - permuted_row_size = (permuted_row_size + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size - - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - m_indices = torch.full((permuted_row_size, ), - fill_invalid_expert, - dtype=torch.int32, - device=hidden_states.device) - expert_first_token_offset = torch.empty(n_local_expert + 1, - dtype=torch.int64, - device=hidden_states.device) - src_row_id2dst_row_id_map = torch.empty((n_token, topk), - dtype=torch.int32, - device=hidden_states.device) - torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, - token_expert_indices, expert_map, n_expert, - n_local_expert, topk, align_block_size, - permuted_hidden_states, - expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) - return (permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) - - -def moe_unpermute( - permuted_hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - expert_first_token_offset: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, -) -> torch.Tensor: - """ - This function expands and permutes activation to gathering uncontinuous - tokens for each expert. - Parameters: - - permuted_hidden_states (torch.Tensor): permuted activation. - - topk_weights (torch.Tensor): topk expert route weight for each token. - - topk_ids (torch.Tensor): topk expert route id for each token. - - expert_first_token_offset (torch.Tensor): offset of the first token - of each expert for grouped gemm. - - topk (int): The number of top-k experts to select. - - n_expert (int): The number of expert. - - n_local_expert (int): The number of expert in current EP rank. - Returns: - - hidden_states (torch.Tensor): The reduced and unpermuted activation - tensor. - """ - n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) - assert (n_hidden * permuted_hidden_states.element_size() - ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" - hidden_states = torch.empty((n_token, n_hidden), - dtype=permuted_hidden_states.dtype, - device=permuted_hidden_states.device) - - torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, - topk_ids, src_row_id2dst_row_id_map, - expert_first_token_offset, n_expert, - n_local_expert, topk, hidden_states) - return hidden_states - - -def moe_permute_unpermute_supported(): - return torch.ops._moe_C.moe_permute_unpermute_supported()
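Taken together, the deleted wrappers bracket a grouped GEMM: moe_permute gathers every expert's tokens into contiguous, optionally block-aligned rows and reports where each group starts, the per-expert GEMMs run over those rows (with m_indices describing the grouped contiguous layout DeepGEMM expects), and moe_unpermute scatters the results back to token order while applying the top-k routing weights. A hedged end-to-end sketch against the pre-removal API, with an identity stand-in for the expert GEMMs:

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute, moe_unpermute)

n_token, n_hidden, n_expert, topk = 256, 2048, 16, 4
hidden = torch.randn(n_token, n_hidden, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(n_token, n_expert, dtype=torch.float32, device="cuda")

topk_weights, topk_ids, token_expert_indices = fused_topk(
    hidden, router_logits, topk, False)

permuted, expert_first_token_offset, src2dst_map, m_indices = moe_permute(
    hidden, topk_weights, topk_ids, token_expert_indices,
    topk, n_expert, n_expert,        # n_local_expert == n_expert: no EP here
    expert_map=None, align_block_size=128)

# ... grouped GEMM over `permuted`, one group per expert, delimited by
# expert_first_token_offset / m_indices ...
gemm_out = permuted                  # identity stand-in for the expert MLPs

out = moe_unpermute(gemm_out, topk_weights, topk_ids, src2dst_map,
                    expert_first_token_offset, topk, n_expert, n_expert)
assert out.shape == (n_token, n_hidden)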