[Feature][EPLB] Add eplb support for Qwen2 #21035

Status: Draft. Wants to merge 1 commit into base: main. Showing changes from all commits.
177 changes: 132 additions & 45 deletions vllm/model_executor/models/qwen2_moe.py
@@ -24,8 +24,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import typing
from collections.abc import Iterable
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
@@ -34,8 +35,9 @@

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_ep_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,6 +52,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_v2 import MixtureOfExperts
Review comment from a Contributor (severity: high):

To improve modularity and reduce coupling between different model implementations, MixtureOfExperts should be imported directly from its definition file. It is defined as a protocol in vllm.model_executor.models.interfaces. Importing from another model's implementation file (deepseek_v2) creates an unnecessary dependency.

Suggested change
-from vllm.model_executor.models.deepseek_v2 import MixtureOfExperts
+from vllm.model_executor.models.interfaces import MixtureOfExperts
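
As background on why the protocol import is preferable: MixtureOfExperts is a structural interface, so a model conforms simply by providing the expected members, with no inheritance from another model's file required. A minimal sketch of the idea follows; the attribute set is abridged and illustrative, not the full protocol from interfaces.py.

from typing import Protocol

import torch


class MixtureOfExperts(Protocol):
    # Abridged, illustrative member set; the real protocol has more.
    expert_weights: list[torch.Tensor]
    num_moe_layers: int
    num_logical_experts: int

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        ...

Any class that defines these members satisfies the protocol, which is what lets the EPLB machinery treat Qwen2MoeForCausalLM and DeepseekV2ForCausalLM uniformly without coupling their modules.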

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

@@ -101,23 +104,47 @@ def __init__(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()

self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts

if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.enable_eplb = enable_eplb

self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size

self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
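
To make the sizing above concrete, here is the arithmetic for a hypothetical deployment: 60 routed experts plus 4 redundant experts spread over an expert-parallel group of 8 ranks (illustrative numbers; the integer division implies that logical plus redundant experts must divide evenly by ep_size).

# Worked example of the EPLB sizing math above (illustrative values).
n_logical_experts = 60       # config.num_experts
n_redundant_experts = 4      # parallel_config.num_redundant_experts
ep_size = 8                  # expert-parallel group size

n_physical_experts = n_logical_experts + n_redundant_experts  # 64
n_local_physical_experts = n_physical_experts // ep_size      # 8

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"rank {ep_rank} owns physical experts [{start}, {end})")
# rank 0 owns physical experts [0, 8)
# ...
# rank 7 owns physical experts [56, 64)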

self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)

self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
@@ -260,6 +287,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -293,7 +321,8 @@ def __init__(
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb)
else:
self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size,
@@ -340,7 +369,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

enable_eplb = vllm_config.parallel_config.enable_eplb
self.num_redundant_experts = (
vllm_config.parallel_config.num_redundant_experts)
self.vocab_size = config.vocab_size
self.config = config

@@ -353,7 +384,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lambda prefix: Qwen2MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=prefix,
enable_eplb=enable_eplb),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -391,15 +423,6 @@ def forward(
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@@ -411,9 +434,17 @@ def load_weights(self, weights: Iterable[tuple[str,
("gate_up_proj", "up_proj", 1),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts)
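
For reference, each entry that FusedMoE.make_expert_params_mapping produces pairs a fused parameter name with one checkpoint expert's weight name, as consumed by the matching loop below. A sketch of the output shape; the exact strings are an assumption for illustration, not vLLM's verbatim output.

# Illustrative (param_name, weight_name, expert_id, shard_id) entries;
# the exact strings are assumed, not copied from vLLM.
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w13_weight", "experts.0.up_proj.", 0, "w3"),
    ("experts.w2_weight", "experts.0.down_proj.", 0, "w2"),
    # ... repeated for every checkpoint expert id
]

The loop below substring-matches each weight_name against the checkpoint tensor name and rewrites it to the fused parameter name before dispatching to the expert-aware weight loader.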

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -443,29 +474,47 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)

# Skip layers on other devices.
if is_pp_missing_parameter(name_mapped, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
if ((name_mapped.endswith(".bias")
or name_mapped.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
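
A note on the return_success pattern above: with redundant experts, a checkpoint expert may have no replica on the current rank, so the loader must report whether it actually consumed the weight instead of silently succeeding. A minimal sketch of that contract, where logical_to_local_slot is a hypothetical helper standing in for the real replica lookup:

# Hypothetical loader showing the return_success contract; this is a
# sketch, not FusedMoE's actual implementation.
def logical_to_local_slot(expert_id):
    """Assumed helper: local slot index for expert_id, or None."""
    ...

def expert_weight_loader(param, loaded_weight, weight_name, *,
                         shard_id, expert_id, return_success=False):
    # None means no replica of this logical expert lives on this rank.
    local_slot = logical_to_local_slot(expert_id)
    if local_slot is None:
        return False if return_success else None
    # Copy the checkpoint tensor into this rank's slot for the expert.
    param.data[local_slot].copy_(loaded_weight)
    return True if return_success else None

Because the mapping loop only breaks once a loader reports success, an expert weight with no local replica falls through to the for-else branch, where is_expert_weight keeps it from being misread as a dense weight.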
@@ -494,20 +543,9 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
class Qwen2MoeForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):

fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -526,6 +564,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# Initialize EPLB-related attributes
self.expert_weights = []

# Set MoE hyperparameters
self.moe_layers: list[FusedMoE] = []
for layer in self.model.layers:
assert isinstance(layer, Qwen2MoeDecoderLayer)
if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):
self.moe_layers.append(layer.mlp.experts)

self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1

example_moe = None
for layer_idx in range(config.num_hidden_layers):
layer = self.model.layers[layer_idx]
if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):
example_moe = layer.mlp
break
assert example_moe is not None

self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_redundant_experts = example_moe.n_redundant_experts
self.num_shared_experts = 0

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

@@ -554,5 +620,26 @@ def load_weights(self, weights: Iterable[tuple[str,
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
"""
Register the EPLB state in the MoE model.

Args:
expert_load_view: A view of the expert load metrics tensor.
logical_to_physical_map: Mapping from logical to physical experts.
logical_replica_count: Count of replicas for each logical expert.
"""
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
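
To show how this hook is meant to be driven, here is a minimal sketch of an EPLB coordinator allocating the state tensors and registering them with an instantiated Qwen2MoeForCausalLM (bound to the name model here). The tensor shapes are assumptions inferred from the parameter names and docstring, not taken from vLLM's actual rebalancer.

# Hypothetical EPLB state setup; shapes are assumptions inferred from
# the docstring above, not vLLM's actual allocation code.
import torch

num_moe_layers = model.num_moe_layers
n_logical = model.num_logical_experts
n_physical = model.num_physical_experts
max_replicas = model.num_redundant_experts + 1  # assumed upper bound

# Per-layer load counter for every physical expert.
expert_load_view = torch.zeros(num_moe_layers, n_physical,
                               dtype=torch.int64)
# For each logical expert, the physical slots holding its replicas.
logical_to_physical_map = torch.full(
    (num_moe_layers, n_logical, max_replicas), -1, dtype=torch.int64)
# Number of live replicas of each logical expert.
logical_replica_count = torch.ones(num_moe_layers, n_logical,
                                   dtype=torch.int64)

model.set_eplb_state(expert_load_view, logical_to_physical_map,
                     logical_replica_count)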