From 48921569b7023b04d8d1ed24e7e55de129c5c9b9 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Wed, 19 Mar 2025 19:40:24 -0600 Subject: [PATCH 01/12] Register Mistral3ForConditionalGeneration as Pixtral --- vllm/model_executor/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b100fe77e377..4ed05269f08e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -216,6 +216,7 @@ "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 + "Mistral3ForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 From 821a7dc905a2131332f3c33a15182bd3e201e6ff Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Thu, 20 Mar 2025 13:47:30 -0600 Subject: [PATCH 02/12] Use Transformers Mistral3-based checkpoints as Pixtral/3.1 Small --- vllm/model_executor/models/pixtral.py | 95 +++++++++++++++++++++++++- vllm/model_executor/models/registry.py | 2 +- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 475d65a58b2a..8353ae324485 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import re import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields @@ -52,6 +53,11 @@ merge_multimodal_embeddings) from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs +import logging +import os +logger = logging.getLogger(__name__) +logger.setLevel(os.environ["VLLM_LOGGING_LEVEL"]) + try: from xformers import ops as xops USE_XFORMERS_OPS = True @@ -82,6 +88,7 @@ def __init__(self, tokenizer: MistralTokenizer) -> None: super().__init__() self.tokenizer = tokenizer + self.logger: logging.Logger = logger @property def image_processor(self) -> ImageEncoder: @@ -147,6 +154,10 @@ def __call__( image_processed = torch.tensor(image_inputs.image) image_tokens = torch.tensor(image_inputs.tokens) + self.logger.debug(f"Image processed shape: {image_processed.shape}") + self.logger.debug(f"Image processed sample (first 15x15 pixels, channel 0): {image_processed[0, :15, :15]}") + self.logger.debug(f"Image tokens: {image_tokens.shape}, first few: {image_tokens[:28]}") + images_processed.append(image_processed) images_tokens.append(image_tokens) @@ -334,12 +345,15 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError("Only image modality is supported") + packed_modules_mapping = {} def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config + + self.logger: logging.Logger = logger dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} 
vision_args = { @@ -364,11 +378,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): eps=1e-5) if self.vision_args.mm_projector_id == PATCH_MERGE: + self.logger.debug("PatchMerger initalizing ...") self.patch_merger = PatchMerger( vision_encoder_dim=self.vision_args.hidden_size, spatial_merge_size=self.vision_args.spatial_merge_size, use_mlp_bias=False, ) + self.logger.debug("PatchMerger:\n\t%s", self.patch_merger) self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size) @@ -376,6 +392,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + logger.debug("Pixtral model layout:") + import rich + from rich.console import Console + from rich.pretty import Pretty + cons = Console() + cons.print("[bold on dark_green] Pixtral model layout") + cons.print(Pretty(self)) + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: images = kwargs.pop("images", None) @@ -480,8 +504,45 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # Reverse mapping from HF to original Pixtral format + MISTRAL3_REVERSE_MAPPING = { + r"^language_model\.lm_head\.weight": r"output.weight", + r"^language_model\.model\.norm\.weight": r"norm.weight", + r"^language_model\.model\.embed_tokens\.weight": r"tok_embeddings.weight", + r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight": r"layers.\1.attention_norm.weight", + r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight": r"layers.\1.ffn_norm.weight", + r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": r"layers.\1.attention.w\2.weight", + r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight": r"layers.\1.feed_forward.w1.weight", + r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight": r"layers.\1.feed_forward.w2.weight", + r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight": r"layers.\1.feed_forward.w3.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight": r"vision_encoder.transformer.layers.\1.attention_norm.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight": r"vision_encoder.transformer.layers.\1.ffn_norm.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight": r"vision_encoder.transformer.layers.\1.attention.w\2.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight", + r"^multi_modal_projector\.linear_1": r"vision_language_adapter.w_in", + r"^multi_modal_projector\.linear_2": r"vision_language_adapter.w_out", + r"^vision_tower\.ln_pre\.weight": r"vision_encoder.ln_pre.weight", + r"^vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight", + r"^multi_modal_projector\.patch_merger\.merging_layer\.weight": r"patch_merger.merging_layer.weight", + r"^multi_modal_projector\.norm\.weight": r"pre_mm_projector_norm.weight", + 
r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": r"layers.\1.\2.\3" + } + + def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> Tuple[str, torch.Tensor]: + """Remap HF-style weight names back to original Pixtral format.""" + self.logger.debug(f"Considering {name}") + + for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items(): + new_name, n_replace = re.subn(pattern, replacement, name) + if n_replace > 0: + self.logger.debug(f"Remapped {name} to {new_name}") + return new_name, tensor + return name, tensor # Return unchanged if no match + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") @@ -504,13 +565,30 @@ def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): vision_lang_adapter_dict = dict( self.vision_language_adapter.named_parameters()) + def inverse_permute_for_rope(tensor, n_heads, dim1, dim2): + """Reverse the permutation applied for ROPE in HF format.""" + tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2) + tensor = tensor.transpose(1, 2) + tensor = tensor.reshape(dim1, dim2) + return tensor + def llm_weights_generator(): # Single pass over weights - for name, w in weights: + remapped_weights = (self.maybe_remap_mistral3(name, w) for name, w in weights) + for name, w in remapped_weights: if is_vision_encoder_weights((name, w)): # Load vision encoder weights directly trimmed_name = '.'.join(name.split(".")[1:]) param = vision_encoder_dict[trimmed_name] + if trimmed_name.startswith("ln_pre"): + logger.debug("loading ln_pre weight now ...") + logger.debug(f"ln_pre weight load sample: {w[:5]}") + if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name: + n_heads = self.vision_args.num_attention_heads + dim1 = param.shape[0] # num_heads * head_dim + dim2 = param.shape[1] # hidden_size + w = inverse_permute_for_rope(w, n_heads, dim1, dim2) + logger.debug(f"Reversed permute_for_rope for {name}, sample: {w[:5, :5]}") with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): @@ -534,6 +612,7 @@ def llm_weights_generator(): else: # LLM weights: yield them to be loaded # by language_model.load_weights + # self.logger.debug(f"Yielding weight {name}; shape {w.shape}") yield (name, w) # Now we call the language model load with the generator @@ -786,9 +865,12 @@ def forward( all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently + logger.debug(f"patch_conv weights sample (first filter, channel 0): {self.patch_conv.weight[0, 0, :5, :5]}") patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] + logger.debug(f"Raw patch conv output shape: {patch_embeds_list[0].shape}") + logger.debug(f"Raw patch conv output sample: {patch_embeds_list[0][0, :5, 0, :5]}") patch_embeds = [ p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list @@ -797,8 +879,14 @@ def forward( # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) + # _ = self.ln_pre.weight.data.fill_(1.0) + logger.debug(f"ln_pre weight sample: {self.ln_pre.weight[:5]}") + # logger.debug("Skipping ln_pre for now ...") patch_embeds = self.ln_pre(patch_embeds) + logger.debug(f"Patch embeddings shape after conv: {patch_embeds.shape}") + logger.debug(f"Patch embeddings sample: {patch_embeds[0, :5, :5]}") + # positional embeddings positions = position_meshgrid(patch_embeds_list).to(self.device) 
freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] @@ -848,6 +936,9 @@ def __init__( mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) + print("mlp_input_dim = {vision_encoder_dim} * ({spatial_merge_size}**2)") + print(f"mlp_input_dim = {vision_encoder_dim} * ({spatial_merge_size}**2)") + self.spatial_merge_size = spatial_merge_size self.mlp_input_dim = mlp_input_dim diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4ed05269f08e..717b2304c275 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -216,7 +216,7 @@ "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 - "Mistral3ForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 + "Mistral3ForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 From 00834b2013cf42a7a611956dbd514e6b1a28bb98 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Thu, 20 Mar 2025 21:37:47 -0600 Subject: [PATCH 03/12] Mistral3.1: silu->GELU to match transformers definition --- vllm/model_executor/models/pixtral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 8353ae324485..e358fee47cd1 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -711,9 +711,10 @@ def __init__(self, args: VisionEncoderArgs): self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.act_fn = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) class Attention(nn.Module): From ec9a53b0d7bb014be82875bd26e3c63fba41a732 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 21 Mar 2025 09:51:08 -0600 Subject: [PATCH 04/12] Hack around `add_pre_mm_projector_layer_norm` incorrect detection. ... Note this probably breaks other Mistral-formatted Pixtrals. Real fix hopefully coming soon. --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index e358fee47cd1..41568ceb41eb 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -633,7 +633,7 @@ class VisionEncoderArgs: image_token_id: int adapter_bias: bool = True spatial_merge_size: int = 1 - add_pre_mm_projector_layer_norm: bool = False + add_pre_mm_projector_layer_norm: bool = True mm_projector_id: str = "" From 01c2d6c0c1698bd188124242b65b5e731f1eea8a Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 21 Mar 2025 10:23:15 -0600 Subject: [PATCH 05/12] Cleanup excessive/unneeded debug prints. ... We still keep our model loading debug prints, but we succesfully resolved the issues addressed by the samples of patch_conv, etc., so all they do now is spam the log. Remove them! 
--- vllm/model_executor/models/pixtral.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 41568ceb41eb..c839df8dfe74 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -154,10 +154,6 @@ def __call__( image_processed = torch.tensor(image_inputs.image) image_tokens = torch.tensor(image_inputs.tokens) - self.logger.debug(f"Image processed shape: {image_processed.shape}") - self.logger.debug(f"Image processed sample (first 15x15 pixels, channel 0): {image_processed[0, :15, :15]}") - self.logger.debug(f"Image tokens: {image_tokens.shape}, first few: {image_tokens[:28]}") - images_processed.append(image_processed) images_tokens.append(image_tokens) @@ -392,14 +388,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - logger.debug("Pixtral model layout:") - import rich - from rich.console import Console - from rich.pretty import Pretty - cons = Console() - cons.print("[bold on dark_green] Pixtral model layout") - cons.print(Pretty(self)) - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: images = kwargs.pop("images", None) @@ -582,7 +570,6 @@ def llm_weights_generator(): param = vision_encoder_dict[trimmed_name] if trimmed_name.startswith("ln_pre"): logger.debug("loading ln_pre weight now ...") - logger.debug(f"ln_pre weight load sample: {w[:5]}") if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name: n_heads = self.vision_args.num_attention_heads dim1 = param.shape[0] # num_heads * head_dim @@ -612,7 +599,6 @@ def llm_weights_generator(): else: # LLM weights: yield them to be loaded # by language_model.load_weights - # self.logger.debug(f"Yielding weight {name}; shape {w.shape}") yield (name, w) # Now we call the language model load with the generator @@ -866,12 +852,9 @@ def forward( all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently - logger.debug(f"patch_conv weights sample (first filter, channel 0): {self.patch_conv.weight[0, 0, :5, :5]}") patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] - logger.debug(f"Raw patch conv output shape: {patch_embeds_list[0].shape}") - logger.debug(f"Raw patch conv output sample: {patch_embeds_list[0][0, :5, 0, :5]}") patch_embeds = [ p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list @@ -881,13 +864,9 @@ def forward( # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) # _ = self.ln_pre.weight.data.fill_(1.0) - logger.debug(f"ln_pre weight sample: {self.ln_pre.weight[:5]}") # logger.debug("Skipping ln_pre for now ...") patch_embeds = self.ln_pre(patch_embeds) - logger.debug(f"Patch embeddings shape after conv: {patch_embeds.shape}") - logger.debug(f"Patch embeddings sample: {patch_embeds[0, :5, :5]}") - # positional embeddings positions = position_meshgrid(patch_embeds_list).to(self.device) freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] From 39ac8163220b451a051d6b5a679b6955fe6bcf5d Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Thu, 27 Mar 2025 06:09:36 -0600 Subject: [PATCH 06/12] Use os.getenv instead of os.environ --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py 
index c839df8dfe74..becc827fdb68 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -56,7 +56,7 @@ import logging import os logger = logging.getLogger(__name__) -logger.setLevel(os.environ["VLLM_LOGGING_LEVEL"]) +logger.setLevel(os.getenv("VLLM_LOGGING_LEVEL", "INFO")) try: from xformers import ops as xops From 7bb4b13bb79fd8b398b5ca590c1096d3c4ff9138 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Thu, 3 Apr 2025 21:56:50 -0600 Subject: [PATCH 07/12] QuantConfig=None on multi_modal_projector --- vllm/model_executor/models/mistral3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 88c3823eaa19..7febe60795a3 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -120,13 +120,13 @@ def __init__(self, self.linear_1 = ColumnParallelLinear(vision_hidden_size, text_hidden_size, bias=multimodal_projector_bias, - quant_config=quant_config, + quant_config=None, prefix=f"{prefix}.linear_1") self.act = get_act_fn(projector_hidden_act) self.linear_2 = RowParallelLinear(text_hidden_size, text_hidden_size, bias=multimodal_projector_bias, - quant_config=quant_config, + quant_config=None, prefix=f"{prefix}.linear_2") def forward(self, image_features: torch.Tensor, From 68d81a9e886ed8f7157ad0ab1a34a784556876ee Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Mon, 7 Apr 2025 20:46:06 -0600 Subject: [PATCH 08/12] Revert "Mistral3.1: silu->GELU to match transformers definition" This reverts commit efee9cc2df2614595f147a1bcf3cfadd997ffae4. --- vllm/model_executor/models/pixtral.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index becc827fdb68..346e180adf51 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -697,10 +697,9 @@ def __init__(self, args: VisionEncoderArgs): self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) - self.act_fn = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): From 97acc5ee55de58239ff19dfb5a6cb100a722153e Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Thu, 5 Jun 2025 00:59:15 -0600 Subject: [PATCH 09/12] Make Pixtral work again after typing updates --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 346e180adf51..56523214cdcc 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -518,7 +518,7 @@ def compute_logits( r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": r"layers.\1.\2.\3" } - def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> Tuple[str, torch.Tensor]: + def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]: """Remap HF-style weight names back to original Pixtral format.""" self.logger.debug(f"Considering {name}") From c5d28d08f19267602ba5b8864e3c1dc272ad4863 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 00:16:46 -0600 Subject: [PATCH 10/12] Use `init_logger` per Gemini suggestion Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- 
vllm/model_executor/models/pixtral.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 56523214cdcc..b0ee259daeb5 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -53,7 +53,9 @@ merge_multimodal_embeddings) from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs -import logging +from vllm.logger import init_logger + +logger = init_logger(__name__) import os logger = logging.getLogger(__name__) logger.setLevel(os.getenv("VLLM_LOGGING_LEVEL", "INFO")) From 36a2846d9c6307056dcc01f8ea6978ac463e71d8 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 00:26:23 -0600 Subject: [PATCH 11/12] Clean up debug logs; rely on `logger_init` for logger. --- vllm/model_executor/models/pixtral.py | 30 +++++++-------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index b0ee259daeb5..3c020e567aad 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import re import math +import re from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property @@ -26,6 +26,7 @@ from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -53,12 +54,7 @@ merge_multimodal_embeddings) from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs -from vllm.logger import init_logger - logger = init_logger(__name__) -import os -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VLLM_LOGGING_LEVEL", "INFO")) try: from xformers import ops as xops @@ -90,7 +86,6 @@ def __init__(self, tokenizer: MistralTokenizer) -> None: super().__init__() self.tokenizer = tokenizer - self.logger: logging.Logger = logger @property def image_processor(self) -> ImageEncoder: @@ -351,8 +346,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config - self.logger: logging.Logger = logger - dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} vision_args = { key: value @@ -376,13 +369,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): eps=1e-5) if self.vision_args.mm_projector_id == PATCH_MERGE: - self.logger.debug("PatchMerger initalizing ...") self.patch_merger = PatchMerger( vision_encoder_dim=self.vision_args.hidden_size, spatial_merge_size=self.vision_args.spatial_merge_size, use_mlp_bias=False, ) - self.logger.debug("PatchMerger:\n\t%s", self.patch_merger) self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size) @@ -522,12 +513,11 @@ def compute_logits( def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]: """Remap HF-style weight names back to original Pixtral format.""" - self.logger.debug(f"Considering {name}") for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items(): new_name, n_replace = re.subn(pattern, replacement, 
name) if n_replace > 0: - self.logger.debug(f"Remapped {name} to {new_name}") + logger.debug(f"remapped %s to %s for Pixtral compat", name, new_name) return new_name, tensor return name, tensor # Return unchanged if no match @@ -570,14 +560,14 @@ def llm_weights_generator(): # Load vision encoder weights directly trimmed_name = '.'.join(name.split(".")[1:]) param = vision_encoder_dict[trimmed_name] - if trimmed_name.startswith("ln_pre"): - logger.debug("loading ln_pre weight now ...") if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name: n_heads = self.vision_args.num_attention_heads dim1 = param.shape[0] # num_heads * head_dim dim2 = param.shape[1] # hidden_size w = inverse_permute_for_rope(w, n_heads, dim1, dim2) - logger.debug(f"Reversed permute_for_rope for {name}, sample: {w[:5, :5]}") + logger.debug( + "reversed permute_for_rope for %s", name + ) with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): @@ -864,8 +854,6 @@ def forward( # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) - # _ = self.ln_pre.weight.data.fill_(1.0) - # logger.debug("Skipping ln_pre for now ...") patch_embeds = self.ln_pre(patch_embeds) # positional embeddings @@ -916,12 +904,10 @@ def __init__( super().__init__() mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) - - print("mlp_input_dim = {vision_encoder_dim} * ({spatial_merge_size}**2)") - print(f"mlp_input_dim = {vision_encoder_dim} * ({spatial_merge_size}**2)") - self.spatial_merge_size = spatial_merge_size self.mlp_input_dim = mlp_input_dim + logger.debug("mlp_input_dim = %d (from %d * (%d ** 2))", mlp_input_dim, + vision_encoder_dim, spatial_merge_size) self.merging_layer = nn.Linear( mlp_input_dim, From f277cf0f49caa4349a7b6f4baf88f11c18c42dd9 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 00:48:50 -0600 Subject: [PATCH 12/12] Apply yapf formatting via pre-commit --- vllm/model_executor/models/pixtral.py | 84 +++++++++++++++++---------- 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 3c020e567aad..ec0e40e29373 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -339,13 +339,14 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError("Only image modality is supported") packed_modules_mapping = {} + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config - + dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} vision_args = { key: value @@ -487,42 +488,66 @@ def compute_logits( # Reverse mapping from HF to original Pixtral format MISTRAL3_REVERSE_MAPPING = { - r"^language_model\.lm_head\.weight": r"output.weight", - r"^language_model\.model\.norm\.weight": r"norm.weight", - r"^language_model\.model\.embed_tokens\.weight": r"tok_embeddings.weight", - r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight": r"layers.\1.attention_norm.weight", - r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight": r"layers.\1.ffn_norm.weight", - r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": r"layers.\1.attention.w\2.weight", - r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight": r"layers.\1.feed_forward.w1.weight", - 
r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight": r"layers.\1.feed_forward.w2.weight", - r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight": r"layers.\1.feed_forward.w3.weight", - r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight": r"vision_encoder.transformer.layers.\1.attention_norm.weight", - r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight": r"vision_encoder.transformer.layers.\1.ffn_norm.weight", - r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight": r"vision_encoder.transformer.layers.\1.attention.w\2.weight", - r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight", - r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight", - r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight", - r"^multi_modal_projector\.linear_1": r"vision_language_adapter.w_in", - r"^multi_modal_projector\.linear_2": r"vision_language_adapter.w_out", - r"^vision_tower\.ln_pre\.weight": r"vision_encoder.ln_pre.weight", - r"^vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight", - r"^multi_modal_projector\.patch_merger\.merging_layer\.weight": r"patch_merger.merging_layer.weight", - r"^multi_modal_projector\.norm\.weight": r"pre_mm_projector_norm.weight", - r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": r"layers.\1.\2.\3" + r"^language_model\.lm_head\.weight": + r"output.weight", + r"^language_model\.model\.norm\.weight": + r"norm.weight", + r"^language_model\.model\.embed_tokens\.weight": + r"tok_embeddings.weight", + r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight": + r"layers.\1.attention_norm.weight", + r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight": + r"layers.\1.ffn_norm.weight", + r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": + r"layers.\1.attention.w\2.weight", + r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight": + r"layers.\1.feed_forward.w1.weight", + r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight": + r"layers.\1.feed_forward.w2.weight", + r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight": + r"layers.\1.feed_forward.w3.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight": + r"vision_encoder.transformer.layers.\1.attention_norm.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight": + r"vision_encoder.transformer.layers.\1.ffn_norm.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight": + r"vision_encoder.transformer.layers.\1.attention.w\2.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight": + r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight": + r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight", + r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight": + r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight", + r"^multi_modal_projector\.linear_1": + r"vision_language_adapter.w_in", + r"^multi_modal_projector\.linear_2": + r"vision_language_adapter.w_out", + r"^vision_tower\.ln_pre\.weight": + r"vision_encoder.ln_pre.weight", + 
r"^vision_tower\.patch_conv\.weight": + r"vision_encoder.patch_conv.weight", + r"^multi_modal_projector\.patch_merger\.merging_layer\.weight": + r"patch_merger.merging_layer.weight", + r"^multi_modal_projector\.norm\.weight": + r"pre_mm_projector_norm.weight", + r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": + r"layers.\1.\2.\3" } - def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]: + def maybe_remap_mistral3(self, name: str, + tensor: torch.Tensor) -> tuple[str, torch.Tensor]: """Remap HF-style weight names back to original Pixtral format.""" for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items(): new_name, n_replace = re.subn(pattern, replacement, name) if n_replace > 0: - logger.debug(f"remapped %s to %s for Pixtral compat", name, new_name) + logger.debug("remapped %s to %s for Pixtral compat", name, + new_name) return new_name, tensor return name, tensor # Return unchanged if no match def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - + def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") @@ -554,7 +579,8 @@ def inverse_permute_for_rope(tensor, n_heads, dim1, dim2): def llm_weights_generator(): # Single pass over weights - remapped_weights = (self.maybe_remap_mistral3(name, w) for name, w in weights) + remapped_weights = (self.maybe_remap_mistral3(name, w) + for name, w in weights) for name, w in remapped_weights: if is_vision_encoder_weights((name, w)): # Load vision encoder weights directly @@ -565,9 +591,7 @@ def llm_weights_generator(): dim1 = param.shape[0] # num_heads * head_dim dim2 = param.shape[1] # hidden_size w = inverse_permute_for_rope(w, n_heads, dim1, dim2) - logger.debug( - "reversed permute_for_rope for %s", name - ) + logger.debug("reversed permute_for_rope for %s", name) with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)):