
Commit 93a126f

[Misc] Make cached tokenizer pickle-compatible (#17048)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 8e4b351 commit 93a126f

File tree

5 files changed: +80 -56 lines changed

benchmarks/benchmark_prefix_caching.py (+8 -6)

@@ -63,14 +63,16 @@ class Request:
     output_len: int
 
 
-def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+def sample_tokens(tokenizer: PreTrainedTokenizerBase,
+                  length: int) -> list[int]:
     vocab = tokenizer.get_vocab()
+    all_special_ids = set(tokenizer.all_special_ids)
+
     # Remove the special tokens.
-    vocab = {
-        k: v
-        for k, v in vocab.items() if k not in tokenizer.all_special_ids
-    }
-    return random.choices(list(vocab.values()), k=length)
+    return random.choices(
+        [v for k, v in vocab.items() if k not in all_special_ids],
+        k=length,
+    )
 
 
 def sample_requests_from_dataset(
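
With this change, sample_tokens returns a list[int] of sampled token ids (the previous -> str annotation did not match what the function actually returned), and all_special_ids is materialized once as a set before sampling. Below is a minimal, self-contained sketch of sampling non-special token ids from a toy vocabulary; the vocabulary and names are illustrative only, not taken from the benchmark:

import random

# Toy stand-ins for tokenizer.get_vocab() and tokenizer.all_special_ids.
vocab = {"<s>": 0, "</s>": 1, "hello": 2, "world": 3, "!": 4}
all_special_ids = {0, 1}


def sample_tokens(length: int) -> list[int]:
    # Sample token ids with replacement, excluding special-token ids.
    candidates = [v for v in vocab.values() if v not in all_special_ids]
    return random.choices(candidates, k=length)


print(sample_tokens(8))  # e.g. [3, 2, 4, 2, 3, 4, 2, 2]
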
+31 -12

@@ -1,24 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
-
+import pickle
 from copy import deepcopy
 
+import pytest
 from transformers import AutoTokenizer
 
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_cached_tokenizer)
 
 
-def test_cached_tokenizer():
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
+def test_cached_tokenizer(model_id: str):
+    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                        trust_remote_code=True)
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
     reference_tokenizer.add_special_tokens(
         {"additional_special_tokens": ["<SEP>"]})
+
     cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
+    _check_consistency(cached_tokenizer, reference_tokenizer)
+
+    pickled_tokenizer = pickle.dumps(cached_tokenizer)
+    unpickled_tokenizer = pickle.loads(pickled_tokenizer)
+    _check_consistency(unpickled_tokenizer, reference_tokenizer)
+
+
+def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
+    assert isinstance(target, type(expected))
+
+    # Cached attributes
+    assert target.all_special_ids == expected.all_special_ids
+    assert target.all_special_tokens == expected.all_special_tokens
+    assert (target.all_special_tokens_extended ==
+            expected.all_special_tokens_extended)
+    assert target.get_vocab() == expected.get_vocab()
+    assert len(target) == len(expected)
+
+    # Other attributes
+    assert getattr(target, "padding_side",
+                   None) == getattr(expected, "padding_side", None)
 
-    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
-        "prompt")
-    assert set(reference_tokenizer.all_special_ids) == set(
-        cached_tokenizer.all_special_ids)
-    assert set(reference_tokenizer.all_special_tokens) == set(
-        cached_tokenizer.all_special_tokens)
-    assert set(reference_tokenizer.all_special_tokens_extended) == set(
-        cached_tokenizer.all_special_tokens_extended)
+    assert target.encode("prompt") == expected.encode("prompt")

vllm/transformers_utils/tokenizer.py (+19 -16)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
+import copy
 import os
 import warnings
 from functools import lru_cache
@@ -70,18 +71,17 @@ def encode_tokens(
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
-    """Get tokenizer with cached properties.
-
-    This will patch the tokenizer object in place.
-
+    """
     By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown. This
-    function caches these properties for faster access."""
+    each time they are called, leading to a significant slowdown.
+    This proxy caches these properties for faster access.
+    """
+    cached_tokenizer = copy.copy(tokenizer)
 
-    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
+    tokenizer_all_special_ids = tokenizer.all_special_ids
+    tokenizer_all_special_tokens = tokenizer.all_special_tokens
     tokenizer_all_special_tokens_extended = (
         tokenizer.all_special_tokens_extended)
-    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_vocab = tokenizer.get_vocab()
     tokenizer_len = len(tokenizer)
 
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
-        def all_special_ids(self):
+        def all_special_ids(self) -> list[int]:
             return tokenizer_all_special_ids
 
         @property
-        def all_special_tokens(self):
+        def all_special_tokens(self) -> list[str]:
             return tokenizer_all_special_tokens
 
         @property
-        def all_special_tokens_extended(self):
+        def all_special_tokens_extended(self) -> list[str]:
             return tokenizer_all_special_tokens_extended
 
         @property
-        def max_token_id(self):
+        def max_token_id(self) -> int:
             return max_token_id
 
-        def get_vocab(self):
+        def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab
 
-        def __len__(self):
+        def __len__(self) -> int:
             return tokenizer_len
 
+        def __reduce__(self):
+            return get_cached_tokenizer, (tokenizer, )
+
     CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
 
-    tokenizer.__class__ = CachedTokenizer
-    return tokenizer
+    cached_tokenizer.__class__ = CachedTokenizer
+    return cached_tokenizer
 
 
 def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
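
The pickling fix here has two parts. First, get_cached_tokenizer now works on a shallow copy instead of patching tokenizer.__class__ in place, so the caller's tokenizer is left untouched. Second, because CachedTokenizer is defined dynamically inside the function, pickle cannot locate it by module-qualified name; __reduce__ sidesteps this by telling pickle to serialize the original tokenizer and rebuild the proxy by calling get_cached_tokenizer again on load. Below is a rough, self-contained sketch of that pattern with a toy class; the names are illustrative and not vLLM's API:

import copy
import pickle


def make_cached(obj):
    # Compute the expensive value once and close over it.
    cached_value = obj.expensive()
    proxy = copy.copy(obj)  # shallow copy; the original stays untouched

    class Cached(obj.__class__):

        def expensive(self):
            return cached_value  # served from the cache, no recomputation

        def __reduce__(self):
            # Pickle cannot import this dynamically created class, so tell it
            # to store the plain object and re-run the factory on load.
            return make_cached, (obj, )

    Cached.__name__ = f"Cached{obj.__class__.__name__}"
    proxy.__class__ = Cached
    return proxy


class Thing:

    def expensive(self):
        return sum(range(1000))


cached = make_cached(Thing())
restored = pickle.loads(pickle.dumps(cached))  # works thanks to __reduce__
assert restored.expensive() == cached.expensive()
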

vllm/transformers_utils/tokenizer_base.py (+17 -17)

@@ -2,7 +2,7 @@
 
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
 
     @property
     @abstractmethod
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         raise NotImplementedError()
 
     @property
@@ -66,7 +66,7 @@ def __len__(self) -> int:
     @abstractmethod
     def __call__(
         self,
-        text: Union[str, List[str], List[int]],
+        text: Union[str, list[str], list[int]],
         text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
@@ -75,11 +75,11 @@ def __call__(
         raise NotImplementedError()
 
     @abstractmethod
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
@@ -88,44 +88,44 @@ def encode_one(
         text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
-    ) -> List[int]:
+    ) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def encode(self,
                text: str,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
+               add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def apply_chat_template(self,
-                            messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[List[Dict[str, Any]]] = None,
-                            **kwargs) -> List[int]:
+                            messages: list["ChatCompletionMessageParam"],
+                            tools: Optional[list[dict[str, Any]]] = None,
+                            **kwargs) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        raise NotImplementedError()
 
     @abstractmethod
     def decode(self,
-               ids: Union[List[int], int],
+               ids: Union[list[int], int],
                skip_special_tokens: bool = True) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def convert_ids_to_tokens(
         self,
-        ids: List[int],
+        ids: list[int],
         skip_special_tokens: bool = True,
-    ) -> List[str]:
+    ) -> list[str]:
         raise NotImplementedError()
 
 
 class TokenizerRegistry:
     # Tokenizer name -> (tokenizer module, tokenizer class)
-    REGISTRY: Dict[str, Tuple[str, str]] = {}
+    REGISTRY: dict[str, tuple[str, str]] = {}
 
     @staticmethod
     def register(name: str, module: str, class_name: str) -> None:
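
The changes in this file, like the mistral.py hunks below, are a mechanical migration from the deprecated typing.List/Dict/Tuple aliases to the built-in generics standardized by PEP 585, which are usable in annotations on Python 3.9 and later. A generic illustration (not code from the commit):

from typing import Dict  # pre-PEP 585 alias, deprecated since Python 3.9


def get_vocab_old() -> Dict[str, int]:  # old-style annotation
    return {}


def get_vocab_new() -> dict[str, int]:  # built-in generic, Python 3.9+
    return {}
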

vllm/transformers_utils/tokenizers/mistral.py (+5 -5)

@@ -257,7 +257,7 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
     # the following attributes are set to fit vLLM's design and are used
     # by the guided structured output backends.
     @property
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
 
         # tekken defines its own extended special tokens list
@@ -271,11 +271,11 @@ def all_special_tokens_extended(self) -> List[str]:
         ]
 
     @property
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         return self.all_special_tokens_extended
 
     @property
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         return [
             self.all_special_tokens.index(t) for t in self.all_special_tokens
         ]
@@ -335,12 +335,12 @@ def __call__(
         input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         # NB: the dictionary form of the vocabulary collapses token ids that map
         # to the same string but have different bytes
         return self._vocab_dict
 
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         # Mistral tokenizers have no added vocabulary
         return {}
 
