diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 84bae87804c1..311498dba6da 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -24,8 +24,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
+import typing
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -34,8 +35,9 @@
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_ep_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,6 +52,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import MixtureOfExperts
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
@@ -101,23 +104,47 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.experts = FusedMoE(num_experts=config.num_experts,
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
+        self.experts = FusedMoE(num_experts=self.n_routed_experts,
                                 top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.moe_intermediate_size,
                                 reduce_results=False,
                                 renormalize=config.norm_topk_prob,
                                 quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
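Note: the partitioning arithmetic introduced in the hunk above is easy to sanity-check in isolation. A minimal sketch, with made-up sizes (8 logical experts, 2 redundant slots, and 2 EP ranks are hypothetical, not Qwen2MoE's real configuration); only the formulas mirror the hunk:

```python
# Toy check of the physical-expert partitioning in Qwen2MoeSparseMoeBlock.
# All sizes below are hypothetical examples.
n_logical_experts = 8
n_redundant_experts = 2
ep_size = 2  # number of expert-parallel ranks

n_physical_experts = n_logical_experts + n_redundant_experts  # 10
# Assumes n_physical_experts is divisible by ep_size, as the hunk does.
n_local_physical_experts = n_physical_experts // ep_size      # 5

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"EP rank {ep_rank} hosts physical experts [{start}, {end})")
# EP rank 0 hosts physical experts [0, 5)
# EP rank 1 hosts physical experts [5, 10)
```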
@@ -260,6 +287,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -293,7 +321,8 @@ def __init__(
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                               quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp")
+                                              prefix=f"{prefix}.mlp",
+                                              enable_eplb=enable_eplb)
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
@@ -340,7 +369,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        self.num_redundant_experts = vllm_config.parallel_config.\
+            num_redundant_experts
         self.vocab_size = config.vocab_size
         self.config = config
 
@@ -353,7 +384,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             lambda prefix: Qwen2MoeDecoderLayer(config=config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
-                                                prefix=prefix),
+                                                prefix=prefix,
+                                                enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -391,15 +423,6 @@ def forward(
             hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -411,9 +434,17 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+            num_redundant_experts=self.num_redundant_experts)
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
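The loading loop in the next hunk leans on Python's `for ... else` idiom twice: the `else` arm runs only if the loop finished without hitting `break`, and the expert path now `break`s only on a successful load. Since the idiom is easy to misread, here is a standalone illustration (all names invented for the example):

```python
# Python's for/else: the else arm runs only when no `break` fired.
def classify(name: str, patterns: list[str]) -> str:
    for pattern in patterns:
        if pattern in name:
            break           # matched: skip the else arm
    else:
        return "fallback"   # loop exhausted without a match

    return f"matched {pattern}"

assert (classify("mlp.experts.3.gate_proj.weight", ["gate_proj"])
        == "matched gate_proj")
assert classify("lm_head.weight", ["gate_proj"]) == "fallback"
```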
@@ -443,27 +474,49 @@ def load_weights(self, weights: Iterable[tuple[str,
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
                     # Skip loading extra bias for GPTQ models.
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if ((name.endswith(".bias") or name.endswith("_bias"))
                             and name not in params_dict):
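Why thread a success flag through the loader? With redundant experts, one checkpoint tensor can map to several physical expert ids, and only the ids in this rank's `[physical_expert_start, physical_expert_end)` slice are materialized locally. A load that falls outside the slice has to report a miss so the loop can try the next replica instead of treating the weight as handled. A toy model of that contract, not the real `FusedMoE` loader:

```python
import torch

# Toy stand-in for an expert weight loader that reports success, assuming
# this rank hosts the physical experts in [local_start, local_end).
def toy_expert_weight_loader(local_bank: torch.Tensor,
                             loaded_weight: torch.Tensor,
                             physical_expert_id: int,
                             local_start: int,
                             local_end: int) -> bool:
    if not (local_start <= physical_expert_id < local_end):
        # This replica lives on another EP rank: report a miss rather
        # than raising, so the caller can try the next physical replica.
        return False
    local_bank[physical_expert_id - local_start].copy_(loaded_weight)
    return True

# A rank hosting physical experts [5, 10): id 3 misses, id 7 lands in slot 2.
bank = torch.zeros(5, 4)
assert not toy_expert_weight_loader(bank, torch.ones(4), 3, 5, 10)
assert toy_expert_weight_loader(bank, torch.ones(4), 7, 5, 10)
assert bank[2].sum() == 4
```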
@@ -494,21 +543,21 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
+class Qwen2MoeForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
 
     fall_back_to_pt_during_load = False
 
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
             "k_proj",
             "v_proj",
         ],
         "gate_up_proj": [
             "gate_proj",
             "up_proj",
         ],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -526,6 +564,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        # Initialize EPLB-related attributes
+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        self.moe_layers: list[FusedMoE] = []
+        for layer in self.model.layers:
+            assert isinstance(layer, Qwen2MoeDecoderLayer)
+            if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):
+                self.moe_layers.append(layer.mlp.experts)
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+
+        example_moe = None
+        for layer_idx in range(config.num_hidden_layers):
+            layer = self.model.layers[layer_idx]
+            if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):
+                example_moe = layer.mlp
+                break
+        assert example_moe is not None
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_shared_experts = 0
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -554,5 +620,26 @@ def load_weights(self, weights: Iterable[tuple[str,
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        """
+        Register the EPLB state in the MoE model.
+
+        Args:
+            expert_load_view: A view of the expert load metrics tensor.
+            logical_to_physical_map: Mapping from logical to physical experts.
+            logical_replica_count: Count of replicas for each logical expert.
+        """
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
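End to end, the feature is driven by the two parallel-config fields this diff reads, `enable_eplb` and `num_redundant_experts`. How those fields are surfaced varies across vLLM versions; the snippet below is a hypothetical usage sketch that assumes the engine arguments carry the same names as the config fields, with an illustrative model and expert count:

```python
# Hypothetical usage sketch; argument names mirror the parallel_config
# fields read in this diff and may differ across vLLM versions.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",
    tensor_parallel_size=4,
    enable_expert_parallel=True,
    enable_eplb=True,          # -> parallel_config.enable_eplb
    num_redundant_experts=4,   # -> parallel_config.num_redundant_experts
)
```

With EPLB enabled, the engine side of the `MixtureOfExperts` interface is expected to allocate the shared `expert_load_view`, `logical_to_physical_map`, and `logical_replica_count` tensors and hand them to the model through `set_eplb_state`, so every `FusedMoE` layer records its load into, and routes from, the same state the rebalancer updates.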