Add support for token_type_ids #19988

Closed · wants to merge 32 commits
Commits
5dee54d
Add support for encoder embedding models
maxdebayser Jun 23, 2025
7eb9d28
Fix CUDA graphs for BERT models
maxdebayser Jul 1, 2025
67691e0
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 1, 2025
d3099a9
Fix cuda graph initialization of token type ids
maxdebayser Jul 1, 2025
613ff3b
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 2, 2025
20c41e4
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 2, 2025
ba86026
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 8, 2025
b4f5ead
Fix missing args
maxdebayser Jul 9, 2025
c4060d1
relax assertion
maxdebayser Jul 9, 2025
01d2a65
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 9, 2025
80930d8
fix missing arg
maxdebayser Jul 9, 2025
d881f0a
fix missing arg
maxdebayser Jul 10, 2025
90a25d0
remove model from unsupported list
maxdebayser Jul 10, 2025
6686550
fix missing arg
maxdebayser Jul 10, 2025
cc76777
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 10, 2025
136c9b3
fix tests
maxdebayser Jul 10, 2025
b232491
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 14, 2025
cf5e6b8
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
e19c738
fix tests
maxdebayser Jul 16, 2025
e255f30
fix tests
maxdebayser Jul 16, 2025
ee5950c
add missing arg
maxdebayser Jul 16, 2025
78a2e57
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
a5cfc84
add missing arg
maxdebayser Jul 16, 2025
63fd783
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
f58692c
Merge branch 'main' into v1_embeddings_full
maxdebayser Jul 20, 2025
eea55fb
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 25, 2025
f2d8e18
Merge branch 'v1_embeddings_full' of github.com:maxdebayser/vllm into…
maxdebayser Jul 25, 2025
12ae080
revert attn changes to simplify merge
maxdebayser Jul 28, 2025
f29da32
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 28, 2025
f0c67f6
fix case of models without tokenizer
maxdebayser Jul 28, 2025
b62a51a
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 29, 2025
cc970ab
simplify score code
maxdebayser Jul 29, 2025
16 changes: 5 additions & 11 deletions tests/models/language/pooling/test_embedding.py
@@ -38,19 +38,13 @@ def v1(run_with_both_engines):
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5",
marks=[
pytest.mark.core_model, pytest.mark.cpu_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
marks=[pytest.mark.skip_v0]),
Contributor review comment (medium): Consider removing the skip mark for v0 tests, as it seems the models are now supported in both engines.

# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
def test_models(
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_jina.py
@@ -26,6 +26,14 @@
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
9 changes: 9 additions & 0 deletions tests/models/language/pooling/test_scoring.py
@@ -23,6 +23,15 @@
"The capital of Germany is Berlin.",
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


DTYPE = "half"


2 changes: 1 addition & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -916,4 +916,4 @@ def test_get_kv_cache_config():
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
])
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -716,6 +716,11 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]:
self.override_pooler_config = PoolerConfig(
**self.override_pooler_config)

# WIP: currently cuda graphs are not working for encoder models.
logger.warning("CUDA graph is not supported for pooling yet, "
"fallback to the eager mode.")
self.enforce_eager = True

pooler_config = self.override_pooler_config or PoolerConfig()

base_config = get_pooling_config(self.model, self.revision)
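Note: the fallback above means encoder pooling models currently run in eager mode on the V1 engine. A minimal usage sketch, assuming the offline `LLM` API with `task="embed"` and the `.embed()` method (both assumed here, not shown in this diff):

```python
# Sketch only: runs an encoder embedding model; CUDA graphs are disabled by the
# config fallback above, so this behaves as if enforce_eager=True were passed.
from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # task name assumed
outputs = llm.embed(["The capital of France is Paris."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```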
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1664,7 +1664,8 @@

if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = default_max_num_seqs[usage_context]
self.max_num_seqs = min(default_max_num_seqs[usage_context],

Check failure on line 1667 in vllm/engine/arg_utils.py (GitHub Actions / pre-commit, reported by multiple jobs):
Value of type variable "SupportsRichComparisonT" of "min" cannot be "Optional[int]" [type-var]

self.max_num_batched_tokens)

logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
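The pre-commit failure above is a mypy type error: `self.max_num_batched_tokens` is `Optional[int]`, which `min` does not accept. A minimal sketch of one way to narrow the type first (attribute names taken from the surrounding diff; this is not necessarily the fix adopted in the PR):

```python
from typing import Optional


def clamp_max_num_seqs(default_value: int,
                       max_num_batched_tokens: Optional[int]) -> int:
    """Clamp the default max_num_seqs by max_num_batched_tokens when set.

    Checking for None narrows Optional[int] to int, which satisfies
    mypy's SupportsRichComparisonT bound on min().
    """
    if max_num_batched_tokens is None:
        return default_value
    return min(default_value, max_num_batched_tokens)


# A sequence is at least one token long, so max_num_seqs never needs to
# exceed the token budget.
assert clamp_max_num_seqs(1024, 8192) == 1024
assert clamp_max_num_seqs(1024, 256) == 256
assert clamp_max_num_seqs(1024, None) == 1024
```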
22 changes: 9 additions & 13 deletions vllm/model_executor/models/bert.py
@@ -12,7 +12,6 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@@ -28,7 +27,7 @@
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .utils import WeightsMapper, maybe_prefix


@@ -57,7 +56,6 @@ def __init__(self, config: BertConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -342,13 +340,9 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)

def load_weights(self, weights: Iterable[tuple[str,
Expand Down Expand Up @@ -388,7 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
@@ -411,11 +405,13 @@ def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

@@ -446,8 +442,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
softmax=False)


class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
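For context on the `token_type_ids` argument threaded through `forward` above: for cross-encoder inputs these ids mark which segment each token belongs to, and they are typically produced by the tokenizer. A short illustration with a Hugging Face tokenizer (model name chosen for illustration only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Encoding a (query, document) pair yields token_type_ids of 0 for the first
# segment and 1 for the second; BertEmbedding adds the corresponding
# token-type embeddings to the word and position embeddings.
enc = tokenizer("What is the capital of France?",
                "The capital of France is Paris.",
                return_token_type_ids=True)
print(enc["token_type_ids"])  # e.g. [0, 0, ..., 0, 1, 1, ..., 1]
```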
38 changes: 16 additions & 22 deletions vllm/model_executor/models/roberta.py
@@ -22,7 +22,7 @@
get_cross_encoder_activation_function)

from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import SupportsCrossEncoding


class RobertaEmbedding(nn.Module):
@@ -52,41 +52,36 @@ def __init__(self, config: RobertaConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:

input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

zero_pos = torch.where(position_ids == 0)[0]
end_pos = torch.cat((zero_pos[1:],
torch.tensor([position_ids.shape[0]],
device=zero_pos.device)))
seq_lens = end_pos - zero_pos

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids, seq_lens.tolist())

pos_list = []
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len

new_pos_list = []
for positions, tokens in zip(pos_list, token_list):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
for tokens in token_list:
pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)

corrected_positions = torch.cat(pos_list)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
position_embeddings = self.position_embeddings(corrected_positions)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
Expand Down Expand Up @@ -150,8 +145,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
assert len(loaded), "Unable to load RobertaEmbeddingModel"


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsV0Only):
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
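The rewritten `forward` above splits the flattened batch per sequence and recomputes RoBERTa position ids instead of asserting that positions arrive as 0..N. A minimal sketch of the padding-aware numbering that `create_position_ids_from_input_ids` is expected to produce (an illustration of the semantics, not the exact vLLM or Hugging Face implementation):

```python
import torch


def position_ids_from_tokens(input_ids: torch.Tensor,
                             padding_idx: int) -> torch.Tensor:
    """Number non-padding tokens starting at padding_idx + 1, as RoBERTa
    expects; padding tokens keep position padding_idx."""
    mask = input_ids.ne(padding_idx).int()
    incremental = torch.cumsum(mask, dim=0) * mask
    return incremental.long() + padding_idx


tokens = torch.tensor([0, 31414, 232, 2, 1, 1])  # 1 is RoBERTa's padding_idx
print(position_ids_from_tokens(tokens, padding_idx=1))
# tensor([2, 3, 4, 5, 1, 1])
```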
30 changes: 25 additions & 5 deletions vllm/v1/attention/backends/flash_attn.py
@@ -386,11 +386,13 @@ def __init__(
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
self.attn_type = attn_type
self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -509,7 +511,7 @@ def forward(
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
causal=_get_causal_option(self.attn_type),
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
@@ -711,3 +713,21 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)


def _get_causal_option(attn_type: str) -> bool:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.

Args:
attn_type (AttentionType): The type of attention being evaluated

Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
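As a quick sanity check of the helper above: only plain decoder self-attention keeps the causal mask, while the encoder-only case enabled by this PR attends bidirectionally. A small standalone sketch with a stand-in enum (the real constants are vLLM's `AttentionType`, already imported in this file):

```python
from enum import Enum


class AttentionType(str, Enum):  # stand-in for vLLM's AttentionType constants
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"
    ENCODER_DECODER = "encoder_decoder"


def get_causal_option(attn_type: AttentionType) -> bool:
    # Mirrors _get_causal_option: causal masking only for decoder self-attention.
    return attn_type == AttentionType.DECODER


assert get_causal_option(AttentionType.DECODER) is True
assert get_causal_option(AttentionType.ENCODER_ONLY) is False  # BERT/RoBERTa
assert get_causal_option(AttentionType.ENCODER_DECODER) is False
```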
1 change: 1 addition & 0 deletions vllm/v1/core/kv_cache_utils.py
@@ -915,6 +915,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
attn_type=str(spec.attn_type),
)

if is_hybrid(kv_cache_spec):
2 changes: 2 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -24,6 +24,7 @@ class NewRequestData:

req_id: str
prompt_token_ids: list[int]
token_type_ids: Optional[list[int]]
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
@@ -42,6 +43,7 @@ def from_request(
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
token_type_ids=request.token_type_ids,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -49,6 +49,7 @@ class EngineCoreRequest(

request_id: str
prompt_token_ids: list[int]
token_type_ids: Optional[list[int]]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
19 changes: 18 additions & 1 deletion vllm/v1/engine/core.py
@@ -35,7 +35,7 @@
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -150,6 +150,23 @@ def _initialize_kv_caches(
zip(kv_cache_specs, available_gpu_memory)
]

for kv_cache_spec_one_worker in kv_cache_specs:
for _, spec in kv_cache_spec_one_worker.items():
if isinstance(spec, AttentionSpec) and \
spec.attn_type != "decoder":

logger.info("Found non-decoder layer. Disabling "
"prefix cache and chunked prefill")
self.vllm_config.cache_config.\
enable_prefix_caching = False
self.vllm_config.scheduler_config.\
enable_chunked_prefill = False
self.vllm_config.scheduler_config.\
chunked_prefill_enabled = False
self.vllm_config.scheduler_config.\
long_prefill_token_threshold = 0
break

# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
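The loop above walks every worker's KV cache specs and, on the first non-decoder layer, turns off prefix caching and chunked prefill, since both rely on causal attention over a shared prefix. A condensed sketch of the same check as a standalone predicate (the spec class here is a stand-in mirroring the `attn_type` field added to `AttentionSpec`):

```python
from dataclasses import dataclass


@dataclass
class FakeAttentionSpec:
    """Stand-in for AttentionSpec; only the attn_type field matters here."""
    attn_type: str  # e.g. "decoder" or "encoder_only"


def has_non_decoder_layer(
        kv_cache_specs: list[dict[str, FakeAttentionSpec]]) -> bool:
    """True if any layer on any worker is not plain decoder attention,
    in which case prefix caching and chunked prefill get disabled."""
    return any(spec.attn_type != "decoder"
               for worker_specs in kv_cache_specs
               for spec in worker_specs.values())


assert has_non_decoder_layer([{"layer_0": FakeAttentionSpec("encoder_only")}])
assert not has_non_decoder_layer([{"layer_0": FakeAttentionSpec("decoder")}])
```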
1 change: 1 addition & 0 deletions vllm/v1/engine/processor.py
@@ -329,6 +329,7 @@ def process_inputs(
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
token_type_ids=decoder_inputs.get("token_type_ids"),
mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
1 change: 1 addition & 0 deletions vllm/v1/kv_cache_interface.py
@@ -75,6 +75,7 @@ class AttentionSpec(KVCacheSpec):
head_size: int
dtype: torch.dtype
use_mla: bool
attn_type: str

@property
def page_size_bytes(self) -> int:
3 changes: 3 additions & 0 deletions vllm/v1/request.py
@@ -24,6 +24,7 @@ def __init__(
self,
request_id: str,
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
@@ -74,6 +75,7 @@ def __init__(
"sampling_params and pooling_params can't both be unset")

self.prompt_token_ids = prompt_token_ids
self.token_type_ids = token_type_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
@@ -118,6 +120,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
token_type_ids=request.token_type_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,