
Map Mistral-HF models back onto Mistral format on-the-fly #20471


Draft: wants to merge 12 commits into base: main (showing changes from 9 commits).
4 changes: 2 additions & 2 deletions vllm/model_executor/models/mistral3.py
@@ -120,13 +120,13 @@ def __init__(self,
self.linear_1 = ColumnParallelLinear(vision_hidden_size,
text_hidden_size,
bias=multimodal_projector_bias,
quant_config=quant_config,
quant_config=None,
Contributor


medium

The quant_config is hardcoded to None. Before finalizing, replace this with a dynamic check to ensure correctness for checkpoints that may have a quantized multi_modal_projector.
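One possible shape for that check, sketched below; the attribute name modules_to_not_convert and the derived projector_quant_config are assumptions for illustration, not the repository's actual API:

# Sketch only: pass quant_config through unless the checkpoint explicitly
# leaves the multi_modal_projector unquantized. "modules_to_not_convert"
# is an assumed attribute name (it appears on some HF quantization
# configs), not a guaranteed field of every vLLM quant_config.
not_converted = getattr(quant_config, "modules_to_not_convert", None) or []
projector_quant_config = (
    None if quant_config is None or "multi_modal_projector" in not_converted
    else quant_config)
self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                     text_hidden_size,
                                     bias=multimodal_projector_bias,
                                     quant_config=projector_quant_config,
                                     prefix=f"{prefix}.linear_1")
# (and likewise for linear_2)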

prefix=f"{prefix}.linear_1")
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = RowParallelLinear(text_hidden_size,
text_hidden_size,
bias=multimodal_projector_bias,
quant_config=quant_config,
quant_config=None,
prefix=f"{prefix}.linear_2")

def forward(self, image_features: torch.Tensor,
76 changes: 73 additions & 3 deletions vllm/model_executor/models/pixtral.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
@@ -52,6 +53,11 @@
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

import logging
import os
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VLLM_LOGGING_LEVEL", "INFO"))

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
@@ -82,6 +88,7 @@
super().__init__()

self.tokenizer = tokenizer
self.logger: logging.Logger = logger

@property
def image_processor(self) -> ImageEncoder:
@@ -334,12 +341,15 @@

raise ValueError("Only image modality is supported")

packed_modules_mapping = {}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config

self.logger: logging.Logger = logger

dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
vision_args = {
@@ -364,11 +374,13 @@
eps=1e-5)

if self.vision_args.mm_projector_id == PATCH_MERGE:
self.logger.debug("PatchMerger initializing ...")
self.patch_merger = PatchMerger(
vision_encoder_dim=self.vision_args.hidden_size,
spatial_merge_size=self.vision_args.spatial_merge_size,
use_mlp_bias=False,
)
self.logger.debug("PatchMerger:\n\t%s", self.patch_merger)

self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size)
@@ -480,8 +492,45 @@
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# Reverse mapping from HF to original Pixtral format
MISTRAL3_REVERSE_MAPPING = {
r"^language_model\.lm_head\.weight": r"output.weight",
r"^language_model\.model\.norm\.weight": r"norm.weight",
r"^language_model\.model\.embed_tokens\.weight": r"tok_embeddings.weight",
r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight": r"layers.\1.attention_norm.weight",
r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight": r"layers.\1.ffn_norm.weight",
r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": r"layers.\1.attention.w\2.weight",
r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight": r"layers.\1.feed_forward.w1.weight",
r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight": r"layers.\1.feed_forward.w2.weight",
r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight": r"layers.\1.feed_forward.w3.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight": r"vision_encoder.transformer.layers.\1.attention_norm.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight": r"vision_encoder.transformer.layers.\1.ffn_norm.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight": r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight",
r"^multi_modal_projector\.linear_1": r"vision_language_adapter.w_in",
r"^multi_modal_projector\.linear_2": r"vision_language_adapter.w_out",
r"^vision_tower\.ln_pre\.weight": r"vision_encoder.ln_pre.weight",
r"^vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight",
r"^multi_modal_projector\.patch_merger\.merging_layer\.weight": r"patch_merger.merging_layer.weight",
r"^multi_modal_projector\.norm\.weight": r"pre_mm_projector_norm.weight",
r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": r"layers.\1.\2.\3"
}

def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
"""Remap HF-style weight names back to original Pixtral format."""
self.logger.debug("Considering %s", name)
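# Example (the first matching pattern in MISTRAL3_REVERSE_MAPPING wins):
#   "language_model.model.layers.0.self_attn.q_proj.weight"
#   -> "layers.0.attention.wq.weight"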

for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items():
new_name, n_replace = re.subn(pattern, replacement, name)
if n_replace > 0:
self.logger.debug("Remapped %s to %s", name, new_name)
return new_name, tensor
return name, tensor # Return unchanged if no match

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder")

@@ -496,21 +545,37 @@

# Get references to parameters for direct loading
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
patch_merger_dict = dict(self.patch_merger.named_parameters(

)) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
pre_mm_projector_norm_dict = dict(
self.pre_mm_projector_norm.named_parameters(
)) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
vision_lang_adapter_dict = dict(

self.vision_language_adapter.named_parameters())

def inverse_permute_for_rope(tensor, n_heads, dim1, dim2):
"""Reverse the permutation applied for ROPE in HF format."""
tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
tensor = tensor.transpose(1, 2)
tensor = tensor.reshape(dim1, dim2)
return tensor
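# Note: this undoes the rotate-half export permutation used by HF-format
# checkpoints (per head: even-index rows first, then odd-index rows),
# restoring the interleaved q/k layout the original Mistral format expects.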

def llm_weights_generator():
# Single pass over weights
for name, w in weights:
remapped_weights = (self.maybe_remap_mistral3(name, w) for name, w in weights)
for name, w in remapped_weights:
if is_vision_encoder_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
if trimmed_name.startswith("ln_pre"):
logger.debug("loading ln_pre weight now ...")
if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name:
n_heads = self.vision_args.num_attention_heads
dim1 = param.shape[0] # num_heads * head_dim
dim2 = param.shape[1] # hidden_size
w = inverse_permute_for_rope(w, n_heads, dim1, dim2)
logger.debug("Reversed permute_for_rope for %s, sample: %s", name, w[:5, :5])
with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)):
@@ -531,13 +596,13 @@
param = vision_lang_adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
else:

# LLM weights: yield them to be loaded
# by language_model.load_weights
yield (name, w)

# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())



# Vision encoder
@@ -554,7 +619,7 @@
image_token_id: int
adapter_bias: bool = True
spatial_merge_size: int = 1
add_pre_mm_projector_layer_norm: bool = False
add_pre_mm_projector_layer_norm: bool = True
Contributor


critical

Changing the default value of add_pre_mm_projector_layer_norm to True is a breaking change for earlier Pixtral models. Revert this change and implement a mechanism to dynamically determine this value from the model's config.json.

Suggested change
add_pre_mm_projector_layer_norm: bool = True
add_pre_mm_projector_layer_norm: bool = False
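One possible direction, sketched rather than prescribed; keying the flag off the HF architectures list is an assumed heuristic for telling Mistral-Small-3.x checkpoints apart from earlier Pixtral ones, not an established rule:

# Sketch: keep the dataclass default at False and override it per checkpoint
# inside PixtralForConditionalGeneration.__init__, for example based on the
# HF architecture name (assumed heuristic), instead of flipping the default.
architectures = getattr(config, "architectures", None) or []
vision_args["add_pre_mm_projector_layer_norm"] = (
    "Mistral3ForConditionalGeneration" in architectures)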

mm_projector_id: str = ""


@@ -797,6 +862,8 @@

# flatten to a single sequence
patch_embeds = torch.cat(patch_embeds, dim=1)
# _ = self.ln_pre.weight.data.fill_(1.0)
# logger.debug("Skipping ln_pre for now ...")
patch_embeds = self.ln_pre(patch_embeds)

# positional embeddings
@@ -848,6 +915,9 @@

mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
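# (spatial merging concatenates spatial_merge_size**2 neighbouring patch
# embeddings into each merged token, hence the extra factor on the
# merging layer's input width)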

print(f"mlp_input_dim = {vision_encoder_dim} * ({spatial_merge_size}**2)")

self.spatial_merge_size = spatial_merge_size
self.mlp_input_dim = mlp_input_dim

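For reference, a small standalone round-trip check of the inverse_permute_for_rope helper added in load_weights above. The permute_for_rope below is a sketch of the permutation commonly applied by HF-style conversion scripts to q/k projection weights; the head count and dimensions are arbitrary example values:

import torch

def permute_for_rope(w, n_heads, dim1, dim2):
    # Rotate-half export permutation (as assumed by this PR's inverse).
    return w.view(n_heads, dim1 // n_heads // 2, 2,
                  dim2).transpose(1, 2).reshape(dim1, dim2)

def inverse_permute_for_rope(w, n_heads, dim1, dim2):
    # Same logic as the helper defined inside load_weights in pixtral.py.
    return w.view(n_heads, 2, dim1 // n_heads // 2,
                  dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.randn(1024, 1024)  # e.g. 16 heads x 64 head_dim, hidden size 1024
restored = inverse_permute_for_rope(
    permute_for_rope(w, 16, 1024, 1024), 16, 1024, 1024)
assert torch.equal(restored, w)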
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -216,6 +216,7 @@
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"Mistral3ForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501

Check failure (GitHub Actions / pre-commit): vllm/model_executor/models/registry.py:219:5: Ruff F601, dictionary key literal `"Mistral3ForConditionalGeneration"` repeated.
"QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501