v1: Add Whisper model support (encoder-decoder) #21088

Draft
wants to merge 21 commits into base: main
Changes from all commits (21 commits)
8a2c588
Support encoder-only models without KV-Cache
maxdebayser Jul 20, 2025
3f11075
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 21, 2025
a416120
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 21, 2025
d845e22
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 21, 2025
1f3fcc4
address review comments
maxdebayser Jul 21, 2025
85bf5fe
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 22, 2025
8e2cba1
remove sliding window attention case
maxdebayser Jul 22, 2025
7357614
address review comment
maxdebayser Jul 22, 2025
aa69e92
make causal a flag in common attention metadata
maxdebayser Jul 22, 2025
838567f
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 23, 2025
d81e143
fix typo
maxdebayser Jul 23, 2025
f0caa0b
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 23, 2025
9318e98
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 24, 2025
068697b
remove encoder model from unsupported test
maxdebayser Jul 24, 2025
837e51b
fix apply_model tests
maxdebayser Jul 24, 2025
bec5419
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 25, 2025
cbc2c4e
Try multiproc in test
maxdebayser Jul 25, 2025
6f64b11
remove quant code
maxdebayser Jul 25, 2025
e39bc74
address review comment
maxdebayser Jul 25, 2025
b406896
Merge branch 'upstream_main' into v1_encoder_only
maxdebayser Jul 25, 2025
01b2a3e
v1: Add Whisper model support (encoder-decoder)
russellb Jul 2, 2025
2 changes: 1 addition & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
@@ -3,12 +3,12 @@
import argparse
import datetime
import os
import re
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
13 changes: 11 additions & 2 deletions tests/conftest.py
@@ -1062,8 +1062,17 @@ def score(
return [req_output.outputs.score for req_output in req_outputs]

def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)
if hasattr(self.llm.llm_engine, "model_executor"):
# This works either in V0 or in V1 with
# VLLM_ENABLE_V1_MULTIPROCESSING=0
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)

# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
def _apply_model(self):
return func(self.get_model())

return self.llm.llm_engine.collective_rpc(_apply_model)

def __enter__(self):
return self
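For context, a minimal sketch of how a test can exercise this helper under V1 multiprocessing. The model name is a placeholder, and the assumption is that VLLM_ALLOW_INSECURE_SERIALIZATION=1 is what lets the callable be pickled across the collective_rpc boundary:

def count_parameters(model):
    # Executed inside each worker process via collective_rpc.
    return sum(p.numel() for p in model.parameters())

def test_apply_model_sketch(vllm_runner, monkeypatch):
    # Needed so the callable can cross the worker-process boundary.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    with vllm_runner(model_name="BAAI/bge-base-en-v1.5", dtype="float16") as llm:
        counts = llm.apply_model(count_parameters)
        # One result per worker; with tp=1 that is a single entry.
        assert len(counts) >= 1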
12 changes: 9 additions & 3 deletions tests/model_executor/test_model_load_with_params.py
@@ -22,10 +22,12 @@

@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_model_loading_with_params(vllm_runner):
def test_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test parameter weight loading with tp>1.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",
@@ -61,10 +63,12 @@ def check_model(model):

@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_roberta_model_loading_with_params(vllm_runner):
def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test parameter weight loading with tp>1.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
@@ -101,10 +105,12 @@ def check_model(model):

@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test loading roberta-base model with no lm_head.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
model_name = "FacebookAI/roberta-base"
with vllm_runner(model_name=model_name,
dtype="float16",
14 changes: 3 additions & 11 deletions tests/models/language/pooling/test_embedding.py
@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param(
"BAAI/bge-base-en-v1.5",
marks=[
# CPU only supports V1
pytest.mark.core_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_jina.py
@@ -23,6 +23,14 @@
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
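For readers unfamiliar with the fixture being wrapped here, a hedged guess at its general shape. The real run_with_both_engines lives elsewhere in the test suite; parametrizing over VLLM_USE_V1 is an assumption, not the actual implementation:

import pytest

@pytest.fixture(params=[0, 1], ids=["engine_v0", "engine_v1"])
def run_with_both_engines(request, monkeypatch):
    # Hypothetical sketch: run the wrapped test once per engine by
    # toggling VLLM_USE_V1; the real fixture also honours the
    # skip_v0 / skip_v1 marks seen in test_embedding.py above.
    monkeypatch.setenv("VLLM_USE_V1", str(request.param))
    yield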
1 change: 1 addition & 0 deletions tests/v1/attention/utils.py
@@ -93,6 +93,7 @@ def create_common_attn_metadata(
max_query_len=max_query_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)


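The new causal field comes from the "make causal a flag in common attention metadata" commit above. As a rough illustration of the intended semantics (plain PyTorch, not vLLM's actual backend code): causal=True keeps the usual lower-triangular decoder mask, while causal=False gives encoder-only models full bidirectional attention.

import torch

def build_attn_bias(seq_len: int, causal: bool) -> torch.Tensor:
    # causal=False -> all-zeros bias, i.e. every token attends to every token
    # (encoder-only models such as BERT/RoBERTa).
    if not causal:
        return torch.zeros(seq_len, seq_len)
    # causal=True -> standard upper-triangular -inf mask (decoder self-attention).
    bias = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(bias, diagonal=1)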
1 change: 0 additions & 1 deletion tests/v1/test_oracle.py
@@ -13,7 +13,6 @@
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
"BAAI/bge-m3", # embedding
]

MODEL = "meta-llama/Llama-3.2-1B-Instruct"
3 changes: 1 addition & 2 deletions tests/v1/test_utils.py
@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re

import pytest
import regex as re
import requests
import torch

1 change: 0 additions & 1 deletion vllm/attention/__init__.py
@@ -14,7 +14,6 @@
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1649,7 +1649,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,

if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = default_max_num_seqs[usage_context]
self.max_num_seqs = min(default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize)

logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
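A small self-contained illustration of the clamping this hunk introduces; the numbers are made up, but they show why the `or sys.maxsize` fallback matters when max_num_batched_tokens is unset:

import sys

default_max_num_seqs = 1024      # example per-usage-context default
max_num_batched_tokens = 256     # example user-provided token budget

# A small token budget also caps the number of sequences...
assert min(default_max_num_seqs, max_num_batched_tokens or sys.maxsize) == 256

# ...while an unset budget (None) leaves the default untouched.
max_num_batched_tokens = None
assert min(default_max_num_seqs, max_num_batched_tokens or sys.maxsize) == 1024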
6 changes: 0 additions & 6 deletions vllm/inputs/preprocess.py
@@ -841,9 +841,6 @@ def preprocess(
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
@@ -873,9 +870,6 @@ async def preprocess_async(
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(prompt)
18 changes: 7 additions & 11 deletions vllm/model_executor/models/bert.py
@@ -12,7 +12,6 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@@ -60,7 +59,6 @@ def __init__(self, config: BertConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -119,7 +117,6 @@ def forward(
return pooled_output


@support_torch_compile
class BertEncoder(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
return hidden_states


@support_torch_compile
class BertModel(nn.Module, SupportsQuant):

is_pooling_model = True
@@ -368,13 +366,9 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)

def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
@@ -474,11 +468,13 @@ def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

85 changes: 64 additions & 21 deletions vllm/model_executor/models/roberta.py
@@ -9,6 +9,7 @@
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,33 +52,12 @@ def __init__(self, config: RobertaConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
seq_lens_list = seq_lens.tolist()
new_pos_list = []
for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
@@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# Fix RoBERTa positions here, outside of the CUDA graph.
# Because we need to extract the sequences from input_ids,
# the control flow is data-dependent.
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)

return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
@@ -216,6 +223,9 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.roberta(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
@@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask

return incremental_indices.long() + padding_idx


def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:

seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))

if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())

offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length
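A quick illustrative check of the convention replace_roberta_positions enforces; the token ids are hand-picked and padding_idx=1 matches the usual RoBERTa configuration:

import torch

tokens = torch.tensor([0, 713, 16, 10, 2, 1, 1])  # last two ids are padding
positions = create_position_ids_from_input_ids(tokens, padding_idx=1)
# Real tokens count up from padding_idx + 1; padding stays at padding_idx.
assert positions.tolist() == [2, 3, 4, 5, 6, 1, 1]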