diff --git a/CMakeLists.txt b/CMakeLists.txt index edc64f87730..857b90d38cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,13 @@ endif() # find_package(Torch REQUIRED) +# +# Ignore nvToolsExt for cuda-12.9 +# +if (NOT TARGET CUDA::nvToolsExt) + add_library(CUDA::nvToolsExt INTERFACE IMPORTED) +endif() + # Supported NVIDIA architectures. # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 8fd8b8220cf..c5e97280e54 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -357,6 +357,7 @@ Specified using `--task generate`. | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 12150cf2a82..63300f2f8da 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -114,6 +114,10 @@ Models that combine Mamba-2 layers with standard attention layers are also suppo these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention backend in V1. +Hybrid models that share similar sub-components to Mamba2 layers, e.g. ShortConv layers in LFM2, are also supported. +As above, they also require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention +backend in V1. + #### Encoder-Decoder Models Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index eba14e64553..73e8545c7be 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -35,6 +35,7 @@ "nvidia/Nemotron-H-8B-Base-8K", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B" ] HF_UNSUPPORTED_MODELS = [ @@ -53,17 +54,21 @@ ] V1_SUPPORTED_MODELS = [ - "mistralai/Mamba-Codestral-7B-v0.1", - "ibm-ai-platform/Bamba-9B-v1", - "Zyphra/Zamba2-1.2B-instruct", - "nvidia/Nemotron-H-8B-Base-8K", - "ibm-granite/granite-4.0-tiny-preview", - "tiiuae/Falcon-H1-0.5B-Base", + "mistralai/Mamba-Codestral-7B-v0.1", "ibm-ai-platform/Bamba-9B-v1", + "Zyphra/Zamba2-1.2B-instruct", "nvidia/Nemotron-H-8B-Base-8K", + "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B" ] # Avoid OOM MAX_NUM_SEQS = 4 +# To be removed once implemented +V1_HYBRID_UNSUPPORTED_ARGS = { + "enforce_eager": True, + "enable_prefix_caching": False, +} + @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @@ -104,8 +109,7 @@ def test_models( m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - enforce_eager=True, - enable_prefix_caching=False) as vllm_model: + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: @@ -183,6 +187,7 @@ def test_chunked_prefill( with vllm_runner(model, enable_chunked_prefill=True, + **V1_HYBRID_UNSUPPORTED_ARGS, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs) as vllm_model: chunked = vllm_model.generate_greedy_logprobs(example_prompts, @@ -190,6 +195,7 @@ def test_chunked_prefill( with vllm_runner(model, enable_chunked_prefill=False, + **V1_HYBRID_UNSUPPORTED_ARGS, max_num_seqs=max_num_seqs) as vllm_model: non_chunked = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) @@ -230,6 +236,7 @@ def test_chunked_prefill_with_parallel_sampling( # forces prefill chunks with decoding max_num_batched_tokens=MAX_NUM_SEQS * 3, max_num_seqs=MAX_NUM_SEQS, + **V1_HYBRID_UNSUPPORTED_ARGS, ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) @@ -254,7 +261,7 @@ def test_mamba_cache_cg_padding( example_prompts.append(example_prompts[0]) try: - with vllm_runner(model) as vllm_model: + with vllm_runner(model, **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) except RuntimeError: pytest.fail( @@ -274,7 +281,9 @@ def test_models_preemption_recompute( """ Tests that outputs are identical with and w/o preemptions (recompute). """ - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: scheduler = vllm_model.model.llm_engine.scheduler[0] scheduler.ENABLE_ARTIFICIAL_PREEMPT = True preempt_vllm_outputs = vllm_model.generate_greedy( @@ -307,7 +316,9 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( a single step. """ try: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: pytest.fail("Hybrid inner state wasn't cleaned up properly between" @@ -327,7 +338,9 @@ def test_state_cleanup( If its not cleaned, an error would be expected. """ try: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: @@ -343,13 +356,17 @@ def test_multistep_correctness( model: str, max_tokens: int, ) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, + num_scheduler_steps=8, + max_num_seqs=2, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_outputs_multistep = vllm_model.generate_greedy( example_prompts, max_tokens) - with vllm_runner(model, num_scheduler_steps=1, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, + num_scheduler_steps=1, + max_num_seqs=2, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_outputs_single_step = vllm_model.generate_greedy( example_prompts, max_tokens) @@ -372,13 +389,17 @@ def test_distributed_correctness( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model, tensor_parallel_size=1, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, + tensor_parallel_size=1, + max_num_seqs=2, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, tensor_parallel_size=2, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, + tensor_parallel_size=2, + max_num_seqs=2, + **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/registry.py b/tests/models/registry.py index 56ae501021f..ff00c86e5f3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -207,6 +207,8 @@ def check_available_online( "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 + "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B", + min_transformers_version="4.54"), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index f94c08c3253..0a226226758 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1379,6 +1379,13 @@ def get_num_layers_by_block_type( # Hybrid model Jamba layers_block_type_value = getattr(self.hf_config, "layers_block_type", None) + + # Hybrid models in transformers >= 4.54.0.dev0 + # populate a `layer_types` attribute + if layers_block_type_value is None: + layers_block_type_value = getattr(self.hf_text_config, + "layer_types", None) + if layers_block_type_value is not None: if hasattr(self.hf_text_config, "model_type") and (self.hf_text_config.model_type @@ -1388,8 +1395,14 @@ def get_num_layers_by_block_type( for t in layers_block_type_value[start:end]) else: return self.get_num_layers(parallel_config) - return sum(t == block_type.value - for t in layers_block_type_value[start:end]) + + # Support with hybrid transformers configs >= 4.54.0.dev0 + if attn_block_type: + return sum(t in ("full_attention", "attention") + for t in layers_block_type_value[start:end]) + else: + return sum(t == block_type.value + for t in layers_block_type_value[start:end]) # Hybrid model Minimax attn_type_list = getattr(self.hf_config, "attn_type_list", None) @@ -1634,9 +1647,10 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" - mamba_page_size_padded: Optional[int] = None - """ Optional override for mamba page size; used by hybrid mamba/attention - models to ensure exact alignment with attention page size.""" + static_cache_page_size_padded: Optional[int] = None + """ Optional override for static cache page size; used by hybrid static + cache (e.g. mamba, short-conv) / attention models to ensure exact alignment + with attention page size.""" # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) @@ -4827,13 +4841,14 @@ def try_verify_and_update_config(self): return from vllm.model_executor.models.config import ( - MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) + MODELS_CONFIG_MAP, HybridAttentionStaticCacheModelConfig) cls = MODELS_CONFIG_MAP.get(architecture, None) if cls is not None: cls.verify_and_update_config(self) if self.model_config.is_hybrid: - HybridAttentionMambaModelConfig.verify_and_update_config(self) + HybridAttentionStaticCacheModelConfig.verify_and_update_config( + self) if self.model_config.task == "classify": # Maybe convert ForCausalLM into ForSequenceClassification model. diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py new file mode 100644 index 00000000000..9ef12bf23bb --- /dev/null +++ b/vllm/model_executor/layers/conv.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import get_current_vllm_config +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, + update_metadata) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.models.conv_cache import ConvCacheParams +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata + + +@CustomOp.register("short_conv") +class ShortConv(CustomOp): + + def __init__(self, config, dim: int, layer_idx: int, prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.conv_dim = dim + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = ColumnParallelLinear( + input_size=self.L_cache, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv.weight.data = self.conv.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[dim] * 3, + bias=self.bias, + prefix=f"{prefix}.in_proj", + ) + self.out_proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.out_proj", + ) + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + # The inner tuple is (conv_state,) + self.kv_cache = [(torch.tensor([]))] + + # For compatibility with StaticCacheSpec utils + self.chunk_size = 1 + self.prefix = prefix + + def forward_native(self, hidden_states: torch.Tensor, + conv_cache_params: ConvCacheParams) -> torch.Tensor: + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + conv_cache_params: ConvCacheParams, + conv_metadata: Mamba2Metadata, + ) -> torch.Tensor: + forward_context = get_forward_context() + # Mamba2Metadata contains metadata necessary for the mamba2 triton + # kernels to operate in continuous batching and in chunked prefill + # modes; they are computed at top-level model forward since they + # stay the same and reused for all mamba layers in the same iteration + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + else: + conv_state = conv_cache_params.conv_state + state_indices_tensor = conv_cache_params.state_indices_tensor + has_initial_states_p = conv_metadata.has_initial_states + + BCx, _ = self.in_proj(hidden_states) + + B, C, x = BCx.chunk(3, dim=-1) + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), + self.conv.weight.size(2)) + + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + Bx = (B * x).contiguous() + hidden_states = C * Bx + contextualized_states, _ = self.out_proj(hidden_states) + return contextualized_states + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill + # Separate prefill and decode by splitting varlen input + # Split along token dimension + if envs.VLLM_USE_V1: + B_d, B_p = torch.split( + B, + [num_decodes, num_prefill_tokens], + dim=0, + ) + C_d, C_p = torch.split( + C, + [num_decodes, num_prefill_tokens], + dim=0, + ) + x_d, x_p = torch.split( + x, + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + B_p, B_d = torch.split( + B, + [num_prefill_tokens, num_decodes], + dim=0, + ) + C_p, C_d = torch.split( + C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + x_p, x_d = torch.split( + x, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) + + conv_output_list = [] + + if has_prefill: + Bx_p = (B_p * x_p).transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(Bx_p, query_start_loc_p, + conv_metadata) + Bx = causal_conv1d_fn(Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=conv_metadata, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + y = C_p * Bx + conv_output_list.append(y) + + if has_decode: + Bx_d = (B_d * x_d).contiguous() + Bx = causal_conv1d_update( + Bx_d, + conv_state, + conv_weights, + self.conv.bias, + activation=None, + conv_state_indices=state_indices_tensor_d) + y = C_d * Bx + if envs.VLLM_USE_V1: + conv_output_list.insert(0, y) + else: + conv_output_list.append(y) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(conv_output_list) + + # Final linear projection + contextualized_states, _ = self.out_proj(hidden_states) + + return contextualized_states + + def get_state_shape(self) -> tuple[tuple[int, ...]]: + world_size = get_tensor_model_parallel_world_size() + # contiguous along 'dim' axis + conv_state_shape = ( + self.L_cache - 1, + divide(self.conv_dim, world_size), + ) + return (conv_state_shape, ) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index e93d4294a62..5b08491113b 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -436,7 +436,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, embedding_padding_modules = ["lm_head"] @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, @@ -524,7 +524,7 @@ def forward(self, LayerBlockType.mamba ) mamba_state_shape = \ - self.get_mamba_state_shape_from_config( + self.get_static_cache_shape_from_config( self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index cb07fe7d9e1..a4279df1d09 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import FullAttentionSpec, StaticCacheSpec if TYPE_CHECKING: @@ -218,16 +218,16 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: config.max_model_len) -class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): +class HybridAttentionStaticCacheModelConfig(VerifyAndUpdateConfig): @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ Ensure that page size of attention layers is greater than or - equal to the mamba layers. If not, automatically set the attention - block size to ensure that it is. If the attention page size is - strictly greater than the mamba page size, we pad the mamba page size - to make them equal. + equal to the static cache layers (e.g. mamba, short-conv). If not, + automatically set the attention block size to ensure that it is. If the + attention page size is strictly greater than the static cache page + size, we pad the static cache page size to make them equal. Args: vllm_config: vLLM Config @@ -257,8 +257,10 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config._model_info.architecture)[0] # get mamba page size - mamba_page_size = MambaSpec( - shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + static_cache_shapes = model_cls.get_static_cache_shape_from_config( + vllm_config) + static_cache_page_size = StaticCacheSpec( + shapes=static_cache_shapes, dtype=kv_cache_dtype, block_size=model_config.max_model_len, ).page_size_bytes @@ -267,7 +269,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # block size to multiple of 16, so let's suggest a value # that would work (note: FA is currently not compatible # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, + attn_block_size = 16 * cdiv(static_cache_page_size, 16 * attn_page_size_1_token) # override attention block size if either (a) the @@ -285,22 +287,23 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: attn_page_size = \ cache_config.block_size * attn_page_size_1_token - assert attn_page_size >= mamba_page_size + assert attn_page_size >= static_cache_page_size - if attn_page_size == mamba_page_size: + if attn_page_size == static_cache_page_size: # don't need to pad mamba page size return # pad mamba page size to exactly match attention - if (cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size): - cache_config.mamba_page_size_padded = (attn_page_size) - mamba_padding_pct = 100 * (attn_page_size - - mamba_page_size) / mamba_page_size + if (cache_config.static_cache_page_size_padded is None or + cache_config.static_cache_page_size_padded != attn_page_size): + cache_config.static_cache_page_size_padded = (attn_page_size) + static_cache_padding_pct = 100 * ( + attn_page_size - + static_cache_page_size) / static_cache_page_size logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) + "Padding static cache page size by %.2f%% to ensure " + "that static cache page size and attention page size are " + "exactly equal.", static_cache_padding_pct) MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { diff --git a/vllm/model_executor/models/conv_cache.py b/vllm/model_executor/models/conv_cache.py new file mode 100644 index 00000000000..9ba930204af --- /dev/null +++ b/vllm/model_executor/models/conv_cache.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch + +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.model_executor.models.constant_size_cache import ConstantSizeCache + + +@dataclass +class ConvCacheParams: + conv_state: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return ConvCacheParams(self.conv_state[layer_idx], + self.state_indices_tensor) + + +class ConvCacheManager(ConstantSizeCache): + + def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, + num_conv_layers: int, conv_state_shape: tuple[int, int]): + + max_batch_size = vllm_config.scheduler_config.max_num_seqs + if not vllm_config.model_config.enforce_eager: + max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) + + # Initialize parent class + super().__init__(max_batch_size) + + # Note(pp): this is for the V0 runner. + # assume conv_state = (dim, state_len). + assert conv_state_shape[0] > conv_state_shape[1] + conv_state = torch.empty(size=(num_conv_layers, max_batch_size) + + (conv_state_shape[1], conv_state_shape[0]), + dtype=dtype, + device="cuda").transpose(-1, -2) + self._lfm2_cache = conv_state + + @property + def cache(self): + return self._lfm2_cache + + def _copy_cache(self, from_index: int, to_index: int): + for cache_t in self.cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def current_run_tensors(self, **kwargs) -> ConvCacheParams: + """ + Return the tensors for the current run's conv state. + """ + cache_tensor, state_indices_tensor = super().current_run_tensors( + **kwargs) + return ConvCacheParams(cache_tensor, state_indices_tensor) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Lfm2 Cache during the CUDA graph + replay runs. + """ + return self._lfm2_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 7761de224c9..b2d94e44e69 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -515,7 +515,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, embedding_padding_modules = ["lm_head"] @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, @@ -617,7 +617,7 @@ def forward( if not envs.VLLM_USE_V1: if self.mamba_cache is None: mamba_state_shape = \ - self.get_mamba_state_shape_from_config( + self.get_static_cache_shape_from_config( self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager( self.vllm_config, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 1c93e90737a..21c88be1514 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -525,7 +525,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, embedding_padding_modules = ["lm_head"] @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, @@ -620,7 +620,7 @@ def forward(self, self.vllm_config.parallel_config, LayerBlockType.mamba)) mamba_state_shape = \ - self.get_mamba_state_shape_from_config( + self.get_static_cache_shape_from_config( self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager(self.vllm_config, self.model_config.dtype, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index b60f1a5b6ff..06c838af861 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -446,12 +446,13 @@ class IsHybrid(Protocol): 'layers_block_type' """ @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, - ) -> tuple[tuple[int, int], tuple[int, int, int]]: - """Calculate shapes for Mamba's convolutional and state caches. + ) -> tuple[tuple[int, int], ...]: + """Calculate shapes for static caches. Currently used for + convolutional and/or SSM state caches (e.g. Mamba, ShortConv). Args: vllm_config: vLLM config diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py new file mode 100755 index 00000000000..0e7b15c86fd --- /dev/null +++ b/vllm/model_executor/models/lfm2.py @@ -0,0 +1,599 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Any, Optional + +import torch +import torch.nn as nn +from transformers import Lfm2Config + +from vllm import envs +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.conv import ShortConv +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.conv_cache import (ConvCacheManager, + ConvCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Lfm2MLP(nn.Module): + + def __init__( + self, + dim: int, + ff_dim: int, + multiple_of: int, + auto_adjust_ff_dim: bool, + ffn_dim_multiplier: Optional[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class Lfm2Attention(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = Lfm2Attention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2ShortConvDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.conv_dim, + layer_idx=layer_idx, + prefix=f"{prefix}.conv", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + conv_cache_params: ConvCacheParams, + conv_metadata: Mamba2Metadata, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.conv( + hidden_states, + conv_cache_params=conv_cache_params, + conv_metadata=conv_metadata, + ) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class Lfm2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = (Lfm2AttentionDecoderLayer + if is_attn else Lfm2ShortConvDecoderLayer) + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, + eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + conv_cache_params: ConvCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_metadata = get_forward_context().attn_metadata + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=1, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + state_cache_index = 0 + for layer in self.layers[self.start_layer:self.end_layer]: + layer_conv_cache_params = None + if isinstance(layer, Lfm2AttentionDecoderLayer): + kv_cache_index += 1 + if isinstance(layer, Lfm2ShortConvDecoderLayer): + current_state_layer = state_cache_index + layer_conv_cache_params = conv_cache_params.at_layer_idx( + current_state_layer) if conv_cache_params else None + state_cache_index += 1 + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + conv_cache_params=layer_conv_cache_params, + conv_metadata=mamba2_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_static_cache_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + """ Calculate shapes for LFM2's convolutional cache. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + world_size = parallel_config.tensor_parallel_size + hidden_size = hf_config.conv_dim + conv_L_cache = hf_config.conv_L_cache + conv_state_shape = ( + conv_L_cache - 1, + hidden_size // world_size, + ) + if not use_v1: + conv_state_shape = (conv_state_shape[1], conv_state_shape[0]) + + return (conv_state_shape, ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "Lfm2 currently does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + + self.model = Lfm2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + # Used to track and store by the Mamba cache between steps. + self.lfm2_cache: Optional[ConvCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + conv_cache_params = None + if not envs.VLLM_USE_V1: + if self.lfm2_cache is None: + num_conv_layers = \ + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.conv + ) + conv_shape = self.get_static_cache_shape_from_config( + self.vllm_config, use_v1=False) + self.lfm2_cache = ConvCacheManager( + vllm_config=self.vllm_config, + dtype=self.lm_head.weight.dtype, + num_conv_layers=num_conv_layers, + conv_state_shape=conv_shape[0], + ) + + conv_cache_params = self.lfm2_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, conv_cache_params, + intermediate_tensors, inputs_embeds) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.lfm2_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.lfm2_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index d812d8cc0a3..b3d83ab0800 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -198,7 +198,7 @@ def load_weights(self, weights: Iterable[tuple[str, class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, @@ -285,7 +285,7 @@ def forward(self, self.vllm_config.parallel_config, LayerBlockType.mamba)) mamba_state_shape = \ - self.get_mamba_state_shape_from_config( + self.get_static_cache_shape_from_config( self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index cf7b39db1fe..aed778ed5f9 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -460,7 +460,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, embedding_padding_modules = ["lm_head"] @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, @@ -548,7 +548,7 @@ def forward(self, LayerBlockType.mamba ) mamba_state_shape = \ - self.get_mamba_state_shape_from_config( + self.get_static_cache_shape_from_config( self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index fd831727ab2..c1ff88a12c8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -84,6 +84,7 @@ "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ebf8dd497f6..f358c04a4a4 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -844,7 +844,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): }) @classmethod - def get_mamba_state_shape_from_config( + def get_static_cache_shape_from_config( cls, vllm_config: "VllmConfig", use_v1: bool = True, @@ -959,7 +959,7 @@ def forward(self, if self.mamba_cache is None: num_mamba_layers = self.config.num_hidden_layers mamba_state_shape = \ - self.get_mamba_state_shape_from_config( + self.get_static_cache_shape_from_config( self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index bbcc2a523dc..eef449286d7 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -225,6 +225,7 @@ class Device(enum.Enum): class LayerBlockType(enum.Enum): attention = "attention" mamba = "mamba" + conv = "conv" class Counter: diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index dca5de46c06..7ae09bccb25 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -11,7 +11,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec, ShortConvSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -89,11 +89,17 @@ class Mamba2AttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec - self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() - assert self.chunk_size is not None, ( - "chunk_size needs to be set in the model config for Mamba2 models") + if isinstance(kv_cache_spec, MambaSpec): + self.chunk_size = \ + vllm_config.model_config.get_mamba_chunk_size() + assert self.chunk_size is not None, ( + "chunk_size needs to be set in the model config for " + "Mamba2 models") + elif isinstance(kv_cache_spec, ShortConvSpec): + self.chunk_size = 1 + else: + raise ValueError(f"Unsupported kv_cache_spec: {kv_cache_spec}") def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 1560406c900..f6b25382aa3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -9,7 +9,8 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + MambaSpec, ShortConvSpec, + SlidingWindowSpec) from vllm.v1.request import Request @@ -406,9 +407,8 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, ) -> tuple[list[KVCacheBlock], ...]: - assert isinstance( - kv_cache_spec, - MambaSpec), ("MambaManager can only be used for mamba groups") + assert isinstance(kv_cache_spec, (MambaSpec, ShortConvSpec)), ( + "MambaManager can only be used for mamba/shortconv groups") # Prefix caching is not supported for mamba now. Always return empty # list. computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( @@ -438,6 +438,7 @@ def allocate_new_blocks(self, request_id: str, ChunkedLocalAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, MambaSpec: MambaManager, + ShortConvSpec: MambaManager, } diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6726709955f..6bf4ef0b998 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -171,7 +171,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass -class MambaSpec(KVCacheSpec): +class StaticCacheSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] dtype: torch.dtype page_size_padded: Optional[int] = None @@ -181,7 +181,8 @@ def __post_init__(self): @property def type_id(self) -> str: - return f"mamba_{self.shapes}_{self.dtype}" + raise NotImplementedError( + "Please instantiate a subclass of StaticCacheSpec") @property def page_size_bytes(self) -> int: @@ -198,6 +199,22 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return self.page_size_bytes +@dataclass +class MambaSpec(StaticCacheSpec): + + @property + def type_id(self) -> str: + return f"mamba_{self.shapes}_{self.dtype}" + + +@dataclass +class ShortConvSpec(StaticCacheSpec): + + @property + def type_id(self) -> str: + return f"short_conv_{self.shapes}_{self.dtype}" + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c3eeb6c2e39..b3874aa3a89 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -29,6 +29,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger +from vllm.model_executor.layers.conv import ShortConv from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader @@ -52,7 +53,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, + KVCacheSpec, MambaSpec, ShortConvSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -2351,7 +2352,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Non-Attention backend is not supported by V1 " "GPUModelRunner.") - elif isinstance(kv_cache_spec, MambaSpec): + elif isinstance(kv_cache_spec, (MambaSpec, ShortConvSpec)): + # ShortConv uses many of the same attributes as Mamba2 path, + # except chunking attn_backend_i = Mamba2AttentionBackend else: raise ValueError( @@ -2447,7 +2450,7 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. """ kv_caches: dict[str, torch.Tensor] = {} - has_attn, has_mamba = False, False + has_attn, has_static_cache_layers = False, False for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec @@ -2485,8 +2488,8 @@ def _reshape_kv_cache_tensors( kv_caches[layer_name] = kv_cache_raw_tensors[ layer_name].view(dtype).view(kv_cache_shape).permute( *inv_order) - elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True + elif isinstance(kv_cache_spec, (MambaSpec, ShortConvSpec)): + has_static_cache_layers = True raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype num_element_per_page = (kv_cache_spec.page_size_bytes // @@ -2510,13 +2513,13 @@ def _reshape_kv_cache_tensors( else: raise NotImplementedError - if has_attn and has_mamba: - self._verify_hybrid_attention_mamba_layout(kv_cache_config, - kv_cache_raw_tensors) + if has_attn and has_static_cache_layers: + self._verify_hybrid_attention_static_cache_layout( + kv_cache_config, kv_cache_raw_tensors) return kv_caches - def _verify_hybrid_attention_mamba_layout( + def _verify_hybrid_attention_static_cache_layout( self, kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: """ @@ -2662,27 +2665,36 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: + for layer_cls, spec_cls in zip((MambaBase, ShortConv), + (MambaSpec, ShortConvSpec)): + static_cache_layers = get_layers_from_vllm_config( + self.vllm_config, layer_cls) + if len(static_cache_layers) > 0: + break + + if len(static_cache_layers) > 0: if self.vllm_config.speculative_config is not None: raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") + "Static cache models with speculative decoding is not " + "yet supported.") if not self.vllm_config.model_config.enforce_eager: raise NotImplementedError( - "Mamba with cuda graph is not supported yet.") + "Static cache models with cuda graph is not " + "yet supported.") if self.vllm_config.cache_config.enable_prefix_caching: raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") + "Prefix caching for static cache models is not " + "yet supported.") max_model_len = self.vllm_config.model_config.max_model_len page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + self.vllm_config.cache_config.static_cache_page_size_padded) - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. - for layer_name, mamba_module in mamba_layers.items(): - kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), + # Set block_size to max_model_len, so that the static cache / + # hybrid model will always have only one block in the KV cache. + for layer_name, static_cache_module in static_cache_layers.items(): + kv_cache_spec[layer_name] = spec_cls( + shapes=static_cache_module.get_state_shape(), dtype=self.kv_cache_dtype, block_size=max_model_len, page_size_padded=page_size_padded)