
[Feature]: Remove xformers requirement for Mistral-format Pixtral and Mistral3 #21062

@mgoin

Description


🚀 The feature, motivation and pitch

I implemented this a while ago for the HF-format Pixtral in #9597 by using PyTorch's SDPA (`scaled_dot_product_attention`) implementation. xformers is not available on all architectures, and most other vision encoders support multiple attention backends. Pixtral may be the only vision encoder that strictly requires xformers.

We should be able to replace the `xops` usage in the `VisionTransformer` and `Attention` classes in `pixtral.py` by following the same substitution as in the HF modules, for example:

```python
if USE_XFORMERS_OPS:
    # xformers: block-diagonal bias so each image's patches attend only to
    # patches of the same image
    attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
        [p.shape[-2] * p.shape[-1] for p in patch_embeds_list])
else:
    from transformers.models.pixtral.modeling_pixtral import (
        generate_block_attention_mask)
    # Dense additive mask with the same block structure, for torch SDPA
    attention_mask = generate_block_attention_mask(
        [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        patch_embeds)
```
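If `transformers` should not be a hard dependency either, the same mask can be built with plain PyTorch. The sketch below is intended to mirror the mask the `transformers` helper produces, under the assumption that it is an additive float mask shaped `(batch, 1, seq, seq)`, holding 0 inside each image's block and the dtype's most negative value elsewhere. `block_diagonal_additive_mask` is a hypothetical helper name, not an existing vLLM or HF function.

```python
import torch


def block_diagonal_additive_mask(seqlens, ref):
    """Additive attention mask letting each image attend only to its own patches.

    Hypothetical helper (not part of vLLM or transformers).
    seqlens: patch count per image, e.g. [p.shape[-2] * p.shape[-1] for p in ...]
    ref: float tensor of shape (batch, seq, ...) supplying dtype, device and
         batch size; sum(seqlens) must equal ref.shape[1].
    """
    seq_len = ref.shape[1]
    assert sum(seqlens) == seq_len, "seqlens must cover the whole sequence"
    # Start with every position blocked (most negative finite value for dtype).
    mask = torch.full((seq_len, seq_len), torch.finfo(ref.dtype).min,
                      dtype=ref.dtype, device=ref.device)
    start = 0
    for n in seqlens:
        # Unblock the square block covering this image's patches.
        mask[start:start + n, start:start + n] = 0
        start += n
    # (seq, seq) -> (batch, 1, seq, seq); the head dim broadcasts inside SDPA.
    return mask[None, None].expand(ref.shape[0], 1, -1, -1)
```

Every row keeps at least its own diagonal block unmasked, so the softmax inside SDPA never sees a fully masked row.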

and, in the attention computation:
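```python
if USE_XFORMERS_OPS:
    # Transpose q and k back for attention
    # (xformers expects (batch, seq, heads, head_dim), like v)
    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()
    out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask)
else:
    # SDPA expects (batch, heads, seq, head_dim), like q and k
    v = v.transpose(1, 2)
    out = nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attention_mask)
    # Back to (batch, seq, heads, head_dim)
    out = out.transpose(1, 2)
```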
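`xops.memory_efficient_attention` takes q, k and v as `(batch, seq, heads, head_dim)`, whereas `scaled_dot_product_attention` expects `(batch, heads, seq, head_dim)`. Assuming the Mistral-format `Attention` class keeps all three tensors in the xformers layout (unlike the HF snippet, where q and k are already head-first), the substitution reduces to transposing in and out. A minimal sketch follows; `sdpa_bshd` and `block_diagonal_bool_mask` are hypothetical helper names. It also checks that SDPA with a block-diagonal mask matches running attention on each image separately, which is the behavior `BlockDiagonalMask` provides.

```python
import itertools

import torch
import torch.nn.functional as F


def sdpa_bshd(q, k, v, attn_mask=None):
    """Hypothetical SDPA stand-in for xops.memory_efficient_attention.

    q, k, v: (batch, seq, heads, head_dim), the xformers layout.
    attn_mask: bool (True = may attend) or additive float mask,
               broadcastable to (batch, heads, seq, seq).
    Returns (batch, seq, heads, head_dim).
    """
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, H, S, D)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return out.transpose(1, 2)  # back to (B, S, H, D)


def block_diagonal_bool_mask(seqlens, device=None):
    """(seq, seq) bool mask: True inside each image's block, False elsewhere."""
    blocks = [torch.ones(n, n, device=device) for n in seqlens]
    return torch.block_diag(*blocks).bool()


if __name__ == "__main__":
    torch.manual_seed(0)
    seqlens = [3, 5]  # two images with 3 and 5 patches
    q, k, v = (torch.randn(1, sum(seqlens), 2, 4) for _ in range(3))
    out = sdpa_bshd(q, k, v, attn_mask=block_diagonal_bool_mask(seqlens))
    # Reference: attention over each image's slice on its own, concatenated.
    offsets = [0, *itertools.accumulate(seqlens)]
    ref = torch.cat([sdpa_bshd(q[:, s:e], k[:, s:e], v[:, s:e])
                     for s, e in zip(offsets[:-1], offsets[1:])], dim=1)
    print(torch.allclose(out, ref, atol=1e-5))
```

The `(seq, seq)` mask broadcasts over batch and heads; for larger batches of images a mask built once per forward pass can be reused across all layers.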

Alternatives

No response

Additional context

No response

