
[Model] Pooling models default to using chunked prefill & prefix caching if supported. #20930
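For context, a minimal usage sketch of the behavior this PR targets. The model name, the `runner="pooling"` argument, and the defaults described in the comments mirror the new tests in this diff; treat them as assumptions rather than guarantees of the final API.

```python
from vllm import LLM

# Sketch only: Qwen/Qwen3-Embedding-0.6B uses LAST pooling over a causal
# backbone, so prefix caching can be enabled for it (and, per this PR,
# such models default to chunked prefill & prefix caching when supported).
llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",
    runner="pooling",
    enable_prefix_caching=True,
)

outputs = llm.embed(["vLLM makes pooling models fast."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```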

Merged
merged 42 commits on Aug 11, 2025
Commits (42); the diff below shows changes from 23 commits
002b4cf
+ default_pooling_type
noooop Aug 5, 2025
4d21759
fix
noooop Aug 5, 2025
964560b
fix
noooop Aug 5, 2025
32411ce
fix pooling
noooop Aug 5, 2025
a97a004
turn off encode
noooop Aug 7, 2025
3f70b32
conflicts
noooop Aug 7, 2025
2b28a50
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 7, 2025
384f406
fix
noooop Aug 7, 2025
8a94d1c
update
noooop Aug 7, 2025
f9d7017
conflicts
noooop Aug 7, 2025
1b28912
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 7, 2025
6638ae0
fix
noooop Aug 7, 2025
bf14fc4
+ tests
noooop Aug 7, 2025
bc2753c
fix
noooop Aug 7, 2025
fef73ea
fix
noooop Aug 7, 2025
8bb0c0c
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 7, 2025
ec212d7
supported_tasks.remove("encode")
noooop Aug 8, 2025
a8ed919
fix
noooop Aug 8, 2025
bef21c5
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 8, 2025
7b4277f
set model_config.supported_tasks inside model runner
noooop Aug 8, 2025
1253f07
fix
noooop Aug 8, 2025
0e29a79
fix
noooop Aug 8, 2025
b6933fc
fix
noooop Aug 8, 2025
37b6827
fix
noooop Aug 8, 2025
2d3fa37
fix
noooop Aug 8, 2025
f42ab13
fix
noooop Aug 8, 2025
d80582a
+ logger.info
noooop Aug 8, 2025
8db0205
fix
noooop Aug 8, 2025
568ed63
fix
noooop Aug 8, 2025
9cd466d
logger.info in runner
noooop Aug 8, 2025
e988353
logger.info in runner
noooop Aug 8, 2025
58038ab
conflicts
noooop Aug 9, 2025
998e9cb
conflicts
noooop Aug 9, 2025
0d4dc95
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 9, 2025
906824e
add back
noooop Aug 9, 2025
6966f30
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 10, 2025
2af0c78
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 10, 2025
cbda86d
conflicts
noooop Aug 11, 2025
63b03cc
conflicts
noooop Aug 11, 2025
c36b7b0
Merge branch 'vllm-project:main' into auto_conversion
noooop Aug 11, 2025
135dffa
add back
noooop Aug 11, 2025
9552a49
fix
noooop Aug 11, 2025
6 changes: 6 additions & 0 deletions tests/entrypoints/llm/test_classify.py
@@ -65,3 +65,9 @@ def get_outputs(activation):
assert torch.allclose(
softmax(wo_activation), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."


def test_encode_api(llm: LLM):
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, use_tqdm=False)
15 changes: 15 additions & 0 deletions tests/entrypoints/openai/test_classification.py
@@ -211,3 +211,18 @@ async def get_outputs(activation):
assert torch.allclose(
F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_pooling(server: RemoteOpenAIServer, model_name: str):
# pooling api uses ALL pooling, which does not support chunked prefill.
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"encoding_format": "float"
},
)
assert response.json()["error"]["type"] == "BadRequestError"
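A client-side sketch of the case exercised above (the server URL/port and model name are assumptions; the error type is the field the test asserts): because the pooling endpoint uses ALL pooling, which is incompatible with chunked prefill, the request is expected to be rejected rather than served.

```python
import requests

# Hypothetical local vLLM OpenAI-compatible server running a pooling model
# with chunked prefill enabled (the new default when supported).
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "test",
        "encoding_format": "float",
    },
)
# ALL pooling is not supported together with chunked prefill, so the server
# returns an error instead of pooled outputs.
print(resp.json()["error"]["type"])  # expected: "BadRequestError"
```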
12 changes: 10 additions & 2 deletions tests/models/language/pooling/mteb_utils.py
@@ -176,9 +176,12 @@ def mteb_test_embed_models(hf_runner,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:

model_config = vllm_model.llm.llm_engine.model_config

if model_info.architecture:
assert (model_info.architecture
in vllm_model.llm.llm_engine.model_config.architectures)
assert model_info.architecture in model_config.architectures
assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type)

vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
@@ -285,7 +288,12 @@ def mteb_test_rerank_models(hf_runner,
**vllm_extra_kwargs) as vllm_model:

model_config = vllm_model.llm.llm_engine.model_config

if model_info.architecture:
assert (model_info.architecture in model_config.architectures)
assert model_config.hf_config.num_labels == 1
assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type)

vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
tasks=MTEB_RERANK_TASKS,
93 changes: 93 additions & 0 deletions tests/models/language/pooling/test_auto_prefix_cache_support.py
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForSequenceClassification

from tests.models.language.pooling.embed_utils import (
run_embedding_correctness_test)


@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:

example_prompts = example_prompts * 2

with vllm_runner(model,
max_model_len=512,
dtype=dtype,
enable_prefix_caching=True) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert cache_config.enable_prefix_caching
vllm_outputs = vllm_model.classify(example_prompts)

with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)

assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)


@pytest.mark.parametrize(
"model",
["Qwen/Qwen3-Embedding-0.6B"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_embed_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
):
example_prompts = [str(s).strip() for s in example_prompts] * 2

with vllm_runner(
model,
runner="pooling",
max_model_len=None,
enable_prefix_caching=True,
) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert cache_config.enable_prefix_caching
vllm_outputs = vllm_model.embed(example_prompts)

with hf_runner(
model,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)


@pytest.mark.parametrize(
"model",
[
"intfloat/e5-small",
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False
"papluca/xlm-roberta-base-language-detection",
])
@pytest.mark.parametrize("dtype", ["half"])
def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str) -> None:
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
enable_prefix_caching=True) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert not cache_config.enable_prefix_caching
117 changes: 61 additions & 56 deletions tests/models/language/pooling/test_baai.py
@@ -2,73 +2,78 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from ...utils import EmbedModelInfo, RerankModelInfo
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
EmbedModelInfo, LASTPoolingEmbedModelInfo,
RerankModelInfo)
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models

MODELS = [
########## BertModel
EmbedModelInfo("BAAI/bge-base-en",
architecture="BertModel",
enable_test=True),
EmbedModelInfo("BAAI/bge-base-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-en",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-en",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-base-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-base-zh-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-zh-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh-v1.5",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-base-en",
architecture="BertModel",
enable_test=True),
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-small-en",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-large-en",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5",
architecture="BertModel",
enable_test=False),
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5",
architecture="BertModel",
enable_test=False),
########## XLMRobertaModel
EmbedModelInfo("BAAI/bge-m3",
architecture="XLMRobertaModel",
enable_test=True),
CLSPoolingEmbedModelInfo("BAAI/bge-m3",
architecture="XLMRobertaModel",
enable_test=True),
########## Qwen2Model
EmbedModelInfo("BAAI/bge-code-v1",
architecture="Qwen2Model",
dtype="float32",
enable_test=True),
LASTPoolingEmbedModelInfo("BAAI/bge-code-v1",
architecture="Qwen2Model",
dtype="float32",
enable_test=True),
]

RERANK_MODELS = [
########## XLMRobertaForSequenceClassification
RerankModelInfo("BAAI/bge-reranker-base",
architecture="XLMRobertaForSequenceClassification",
enable_test=True),
RerankModelInfo("BAAI/bge-reranker-large",
architecture="XLMRobertaForSequenceClassification",
enable_test=False),
RerankModelInfo("BAAI/bge-reranker-v2-m3",
architecture="XLMRobertaForSequenceClassification",
enable_test=False)
CLSPoolingRerankModelInfo(
"BAAI/bge-reranker-base",
architecture="XLMRobertaForSequenceClassification",
enable_test=True),
CLSPoolingRerankModelInfo(
"BAAI/bge-reranker-large",
architecture="XLMRobertaForSequenceClassification",
enable_test=False),
CLSPoolingRerankModelInfo(
"BAAI/bge-reranker-v2-m3",
architecture="XLMRobertaForSequenceClassification",
enable_test=False)
]


8 changes: 4 additions & 4 deletions tests/models/language/pooling/test_bge_reranker_v2_gemma.py
@@ -8,12 +8,12 @@

from tests.conftest import HfRunner

from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
mteb_test_rerank_models)
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models

RERANK_MODELS = [
RerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"),
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
12 changes: 7 additions & 5 deletions tests/models/language/pooling/test_cross_encoder.py
@@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo,
RerankModelInfo)
from .mteb_utils import mteb_test_rerank_models

RERANK_MODELS = [
RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
architecture="BertForSequenceClassification"),
RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
architecture="Qwen3ForSequenceClassification")
CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
architecture="BertForSequenceClassification"),
LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
architecture="Qwen3ForSequenceClassification")
]

