diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index eca37a09058a..144e471ea443 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -316,6 +316,7 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | diff --git a/tests/models/registry.py b/tests/models/registry.py index 1207a928c92f..9d3fc8a1b1cf 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -141,6 +141,8 @@ def check_available_online( trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py new file mode 100644 index 000000000000..325ba7bbad83 --- /dev/null +++ b/vllm/model_executor/models/bailing_moe.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py +# Copyright 2023 The vLLM team. +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BailingMoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class BailingAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.total_num_heads % tp_size == 0 + assert self.total_kv_heads % tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = config.head_dim or (self.hidden_size // + self.total_num_heads) + self.q_size_per_rank = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // tp_size + self.kv_size_per_rank = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_kv_heads, + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + is_neox_style=True, + rope_scaling=config.rope_scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([ + self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank + ], + dim=-1) + + q, k = self.rotary_emb(position_ids, q, k) + + context_layer = self.attn(q, k, v) + + attn_output, _ = self.dense(context_layer) + return attn_output + + +class BailingMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BailingMoE(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_expert_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + if self.num_shared_experts > 0: + intermediate_size = (config.moe_intermediate_size * + self.num_shared_experts) + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts") + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + +class BailingMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.attention = BailingAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.post_attention_layernorm = RMSNorm(hidden_size, + eps=config.rms_norm_eps) + self.mlp = BailingMoE(intermediate_size, + config, + quant_config, + True, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.attention( + hidden_states=hidden_states, + position_ids=position_ids, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BailingMoeModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, self.embed_dim) + else: + self.word_embeddings = PPMissingLayer() + + self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BailingMoeBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers") + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + hidden_states, + position_ids, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.norm_head and "lm_head.weight" in name: + loaded_weight = F.normalize(loaded_weight, + dim=0, + p=2, + eps=1e-7) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BailingMoeForCausalLM(nn.Module, SupportsPP): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.max_position_embeddings = config.max_position_embeddings + self.model = BailingMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = (self.word_embeddings if config.tie_word_embeddings + else ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config)) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b7d4789549aa..79190860ac94 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -41,6 +41,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),