Map Mistral-HF models back onto Mistral format on-the-fly #20471
base: main
Changes from 9 commits
vllm/model_executor/models/pixtral.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields

@@ -52,6 +53,11 @@
                    merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

import logging
import os

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VLLM_LOGGING_LEVEL", "INFO"))

try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
@@ -82,6 +88,7 @@
        super().__init__()

        self.tokenizer = tokenizer
        self.logger: logging.Logger = logger

    @property
    def image_processor(self) -> ImageEncoder:
@@ -334,12 +341,15 @@

        raise ValueError("Only image modality is supported")

    packed_modules_mapping = {}
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.logger: logging.Logger = logger

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
@@ -364,11 +374,13 @@
                                                 eps=1e-5)

        if self.vision_args.mm_projector_id == PATCH_MERGE:
            self.logger.debug("PatchMerger initializing ...")
            self.patch_merger = PatchMerger(
                vision_encoder_dim=self.vision_args.hidden_size,
                spatial_merge_size=self.vision_args.spatial_merge_size,
                use_mlp_bias=False,
            )
            self.logger.debug("PatchMerger:\n\t%s", self.patch_merger)

        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)
@@ -480,8 +492,45 @@
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    # Reverse mapping from HF to original Pixtral format
    MISTRAL3_REVERSE_MAPPING = {
        r"^language_model\.lm_head\.weight": r"output.weight",
        r"^language_model\.model\.norm\.weight": r"norm.weight",
        r"^language_model\.model\.embed_tokens\.weight": r"tok_embeddings.weight",
        r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight": r"layers.\1.attention_norm.weight",
        r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight": r"layers.\1.ffn_norm.weight",
        r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": r"layers.\1.attention.w\2.weight",
        r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight": r"layers.\1.feed_forward.w1.weight",
        r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight": r"layers.\1.feed_forward.w2.weight",
        r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight": r"layers.\1.feed_forward.w3.weight",
        r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight": r"vision_encoder.transformer.layers.\1.attention_norm.weight",
        r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight": r"vision_encoder.transformer.layers.\1.ffn_norm.weight",
        r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight": r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight",
        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight",
        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight",
        r"^multi_modal_projector\.linear_1": r"vision_language_adapter.w_in",
        r"^multi_modal_projector\.linear_2": r"vision_language_adapter.w_out",
        r"^vision_tower\.ln_pre\.weight": r"vision_encoder.ln_pre.weight",
        r"^vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight",
        r"^multi_modal_projector\.patch_merger\.merging_layer\.weight": r"patch_merger.merging_layer.weight",
        r"^multi_modal_projector\.norm\.weight": r"pre_mm_projector_norm.weight",
        r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": r"layers.\1.\2.\3",
    }

    def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
        """Remap HF-style weight names back to original Pixtral format."""
        self.logger.debug(f"Considering {name}")

        for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items():
            new_name, n_replace = re.subn(pattern, replacement, name)
            if n_replace > 0:
                self.logger.debug(f"Remapped {name} to {new_name}")
                return new_name, tensor
        return name, tensor  # Return unchanged if no match
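
As an aside (not part of the diff), this is how the remapping behaves on two representative HF-style names. The harness below just re-applies two of the patterns above with re.subn; the sample weight names are illustrative:

import re

REVERSE_MAPPING_SAMPLE = {
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight":
        r"layers.\1.attention.w\2.weight",
    r"^vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight",
}

for hf_name in ("language_model.model.layers.7.self_attn.q_proj.weight",
                "vision_tower.patch_conv.weight"):
    for pattern, replacement in REVERSE_MAPPING_SAMPLE.items():
        new_name, n_replace = re.subn(pattern, replacement, hf_name)
        if n_replace > 0:
            print(f"{hf_name} -> {new_name}")
            break
# language_model.model.layers.7.self_attn.q_proj.weight -> layers.7.attention.wq.weight
# vision_tower.patch_conv.weight -> vision_encoder.patch_conv.weight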

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

@@ -496,21 +545,37 @@

        # Get references to parameters for direct loading
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        patch_merger_dict = dict(self.patch_merger.named_parameters(
        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
        pre_mm_projector_norm_dict = dict(
            self.pre_mm_projector_norm.named_parameters(
            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())

        def inverse_permute_for_rope(tensor, n_heads, dim1, dim2):
            """Reverse the permutation applied for RoPE in HF format."""
            tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            tensor = tensor.transpose(1, 2)
            tensor = tensor.reshape(dim1, dim2)
            return tensor
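
Aside for reviewers (not part of the diff): this helper is meant to invert the head-interleaving permutation that the HF export applies to the q/k projection weights. A quick, self-contained round-trip check of that assumption with toy sizes; the forward permute_for_rope shown here is the assumed HF-export form, not code from this file:

import torch

def permute_for_rope(w, n_heads, dim1, dim2):
    # Forward permutation in the assumed HF-export form.
    return w.view(n_heads, dim1 // n_heads // 2, 2,
                  dim2).transpose(1, 2).reshape(dim1, dim2)

def inverse_permute_for_rope(w, n_heads, dim1, dim2):
    # Same body as the helper in the diff above.
    return w.view(n_heads, 2, dim1 // n_heads // 2,
                  dim2).transpose(1, 2).reshape(dim1, dim2)

n_heads, dim1, dim2 = 2, 16, 8  # toy sizes
w = torch.randn(dim1, dim2)
restored = inverse_permute_for_rope(
    permute_for_rope(w, n_heads, dim1, dim2), n_heads, dim1, dim2)
assert torch.equal(restored, w)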

        def llm_weights_generator():
            # Single pass over weights
            remapped_weights = (self.maybe_remap_mistral3(name, w) for name, w in weights)
            for name, w in remapped_weights:
                if is_vision_encoder_weights((name, w)):
                    # Load vision encoder weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_encoder_dict[trimmed_name]
                    if trimmed_name.startswith("ln_pre"):
                        logger.debug("loading ln_pre weight now ...")
                    if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name:
                        n_heads = self.vision_args.num_attention_heads
                        dim1 = param.shape[0]  # num_heads * head_dim
                        dim2 = param.shape[1]  # hidden_size
                        w = inverse_permute_for_rope(w, n_heads, dim1, dim2)
                        logger.debug(f"Reversed permute_for_rope for {name}, sample: {w[:5, :5]}")
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_patch_merger((name, w)):
@@ -531,13 +596,13 @@
                    param = vision_lang_adapter_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())


# Vision encoder
@@ -554,7 +619,7 @@
    image_token_id: int
    adapter_bias: bool = True
    spatial_merge_size: int = 1
    add_pre_mm_projector_layer_norm: bool = True  # default changed from False in this PR
    mm_projector_id: str = ""

Review comment on this line: Changing the default value of add_pre_mm_projector_layer_norm …
@@ -797,6 +862,8 @@

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        # _ = self.ln_pre.weight.data.fill_(1.0)
        # logger.debug("Skipping ln_pre for now ...")
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
@@ -848,6 +915,9 @@

        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

        print(f"mlp_input_dim = {vision_encoder_dim} * ({spatial_merge_size}**2)")

        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = mlp_input_dim

Review comment: The quant_config is hardcoded to None. Before finalizing, replace this with a dynamic check to ensure correctness for checkpoints that may have a quantized multi_modal_projector.
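
One way such a dynamic check could be sketched, offered as an illustration rather than the reviewer's actual suggestion: look for quantization tensor suffixes (the same ones the last mapping entry handles for language-model layers) under multi_modal_projector in the incoming HF weight names, and only pass a real quant_config when they are present. Only vllm_config.quant_config is an existing vLLM attribute here; the helper and checkpoint_weight_names are hypothetical:

from collections.abc import Iterable

# Quantization tensor suffixes, matching those the reverse mapping already
# recognizes for language-model layers.
QUANT_SUFFIXES = ("g_idx", "zp", "scales", "zeros", "qweight", "qzeros")

def projector_is_quantized(weight_names: Iterable[str]) -> bool:
    """True if any multi_modal_projector tensor carries a quantization suffix."""
    return any(
        name.startswith("multi_modal_projector.") and name.endswith(QUANT_SUFFIXES)
        for name in weight_names
    )

# Usage sketch (checkpoint_weight_names is hypothetical; quant_config would
# still need to be threaded into the projector's layers to take effect):
# quant_config = (vllm_config.quant_config
#                 if projector_is_quantized(checkpoint_weight_names) else None)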