
Commit 4b8f85c

Use token counts for SimilarLengthsBatchifyer (#155)
* use token counts for SimilarLengthsBatchifyer and use the batchifyer for dense embedding model too
* make synchronization device agnostic
1 parent fc49f9d commit 4b8f85c

4 files changed: 281 additions & 87 deletions

retrieval.py

Lines changed: 5 additions & 6 deletions
@@ -12,22 +12,21 @@
 from bs4 import BeautifulSoup
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import optimum.bettertransformer.transformation
-from sentence_transformers import SentenceTransformer
 
 try:
     from .retrievers.faiss_retriever import FaissRetriever
     from .retrievers.bm25_retriever import BM25Retriever
     from .retrievers.splade_retriever import SpladeRetriever
     from .chunkers.semantic_chunker import BoundedSemanticChunker
     from .chunkers.character_chunker import RecursiveCharacterTextSplitter
-    from .utils import Document
+    from .utils import Document, MySentenceTransformer
 except ImportError:
     from retrievers.faiss_retriever import FaissRetriever
     from retrievers.bm25_retriever import BM25Retriever
     from retrievers.splade_retriever import SpladeRetriever
     from chunkers.semantic_chunker import BoundedSemanticChunker
     from chunkers.character_chunker import RecursiveCharacterTextSplitter
-    from utils import Document
+    from utils import Document, MySentenceTransformer
 
 
 class DocumentRetriever:
@@ -37,9 +36,9 @@ def __init__(self, device="cuda", num_results: int = 5, similarity_threshold: fl
                  model_cache_dir: str = None, chunking_method: str = "character-based",
                  chunker_breakpoint_threshold_amount: int = 10, client_timeout: int = 10):
         self.device = device
-        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=model_cache_dir,
-                                                   device=device,
-                                                   model_kwargs={"torch_dtype": torch.float32 if device == "cpu" else torch.float16})
+        self.embedding_model = MySentenceTransformer("all-MiniLM-L6-v2", cache_folder=model_cache_dir,
+                                                     device=device,
+                                                     model_kwargs={"torch_dtype": torch.float32 if device == "cpu" else torch.float16})
         if keyword_retriever == "splade":
             splade_kwargs = {"cache_dir": model_cache_dir,
                              "torch_dtype": torch.float32 if device == "cpu" else torch.float16,

retrievers/faiss_retriever.py

Lines changed: 4 additions & 5 deletions
@@ -2,17 +2,16 @@
 
 import faiss
 import numpy as np
-from sentence_transformers import SentenceTransformer
 
 try:
-    from ..utils import Document, cosine_similarity
+    from ..utils import Document, cosine_similarity, MySentenceTransformer, SimilarLengthsBatchifyer
 except:
-    from utils import Document, cosine_similarity
+    from utils import Document, cosine_similarity, MySentenceTransformer, SimilarLengthsBatchifyer
 
 
 class FaissRetriever:
 
-    def __init__(self, embedding_model: SentenceTransformer, num_results: int = 5, similarity_threshold: float = 0.5):
+    def __init__(self, embedding_model: MySentenceTransformer, num_results: int = 5, similarity_threshold: float = 0.5):
         self.embedding_model = embedding_model
         self.num_results = num_results
         self.similarity_threshold = similarity_threshold
@@ -24,7 +23,7 @@ def add_documents(self, documents: List[Document]):
         if not documents:
             return
         self.documents = documents
-        self.document_embeddings = self.embedding_model.encode([doc.page_content for doc in documents])
+        self.document_embeddings = self.embedding_model.batch_encode([doc.page_content for doc in documents])
         self.index.add(self.document_embeddings)
 
     def get_relevant_documents(self, query: str) -> List[Document]:
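
Because add_documents feeds the returned matrix straight into the FAISS index, batch_encode has to return embeddings in the original document order, not in the length-grouped order the batchifyer iterates in. A short usage sketch (illustrative only):

# Illustrative usage: row i of document_embeddings must correspond to documents[i]
embedding_model = MySentenceTransformer("all-MiniLM-L6-v2", device="cuda")
retriever = FaissRetriever(embedding_model, num_results=5, similarity_threshold=0.5)
retriever.add_documents(documents)  # documents: List[Document]
results = retriever.get_relevant_documents("example query")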

retrievers/splade_retriever.py

Lines changed: 5 additions & 74 deletions
@@ -11,80 +11,9 @@
 from scipy.sparse import csr_array
 
 try:
-    from ..utils import Document
+    from ..utils import Document, SimilarLengthsBatchifyer
 except:
-    from utils import Document
-
-
-class SimilarLengthsBatchifyer:
-    """
-    Generator class to split samples into batches. Groups sample sequences
-    of equal/similar length together to minimize the need for padding within a batch.
-    """
-    def __init__(self, batch_size, inputs, max_padding_len=10):
-        # Remember number of samples
-        self.num_samples = len(inputs)
-
-        self.unique_lengths = set()
-        self.length_to_sample_indices = {}
-
-        for i in range(0, len(inputs)):
-            len_input = len(inputs[i])
-
-            self.unique_lengths.add(len_input)
-
-            # For each length, keep track of the indices of the samples that have this length
-            # E.g.: self.length_to_sample_indices = { 3: [3,5,11], 4: [1,2], ...}
-            if len_input in self.length_to_sample_indices:
-                self.length_to_sample_indices[len_input].append(i)
-            else:
-                self.length_to_sample_indices[len_input] = [i]
-
-        # Use a dynamic batch size to speed up inference at a constant VRAM usage
-        self.unique_lengths = sorted(list(self.unique_lengths))
-        max_chars_per_batch = self.unique_lengths[-1] * batch_size
-        self.length_to_batch_size = {length: int(max_chars_per_batch / (length * batch_size)) * batch_size for length in self.unique_lengths}
-
-        # Merge samples of similar lengths in those cases where the amount of samples
-        # of a particular length is < dynamic batch size
-        accum_len_diff = 0
-        for i in range(1, len(self.unique_lengths)):
-            if accum_len_diff >= max_padding_len:
-                accum_len_diff = 0
-                continue
-            curr_len = self.unique_lengths[i]
-            prev_len = self.unique_lengths[i-1]
-            len_diff = curr_len - prev_len
-            if (len_diff <= max_padding_len and
-                    (len(self.length_to_sample_indices[curr_len]) < self.length_to_batch_size[curr_len]
-                     or len(self.length_to_sample_indices[prev_len]) < self.length_to_batch_size[prev_len])):
-                self.length_to_sample_indices[curr_len].extend(self.length_to_sample_indices[prev_len])
-                self.length_to_sample_indices[prev_len] = []
-                accum_len_diff += len_diff
-            else:
-                accum_len_diff = 0
-
-    def __len__(self):
-        return self.num_samples
-
-    def __iter__(self):
-        # Iterate over all possible sentence lengths
-        for length in self.unique_lengths:
-
-            # Get indices of all samples for the current length
-            # for example, all indices of samples with a length of 7
-            sequence_indices = self.length_to_sample_indices[length]
-            if len(sequence_indices) == 0:
-                continue
-
-            dyn_batch_size = self.length_to_batch_size[length]
-
-            # Compute the number of batches
-            num_batches = np.ceil(len(sequence_indices) / dyn_batch_size)
-
-            # Loop over all possible batches
-            for batch_indices in np.array_split(sequence_indices, num_batches):
-                yield batch_indices
+    from utils import Document, SimilarLengthsBatchifyer
 
 
 def neg_dot_dist(x, y):
@@ -112,7 +41,9 @@ def __init__(self, splade_doc_tokenizer, splade_doc_model, splade_query_tokenize
     def compute_document_vectors(self, texts: List[str], batch_size: int) -> Tuple[List[List[int]], List[List[float]]]:
         indices = []
         values = []
-        batchifyer = SimilarLengthsBatchifyer(batch_size, texts)
+        tokenized_texts = self.splade_doc_tokenizer(texts, truncation=False, padding=False,
+                                                    return_tensors="np")["input_ids"]
+        batchifyer = SimilarLengthsBatchifyer(batch_size, tokenized_texts)
         texts = np.array(texts)
         batch_indices = []
         for index_batch in batchifyer:
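
The batchifyer previously grouped texts by character count; passing tokenized input means it now groups by token count, which is what actually determines padding and VRAM usage in the model forward pass. A small illustration of why the two can diverge (the model name is only an example, not necessarily the one this repo uses):

# Character counts and token counts can diverge a lot, so grouping by
# characters was a weak proxy for the padded batch width.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
texts = ["https://example.com/some/long/path?page=1&sort=asc",
         "a short plain sentence"]
token_ids = tokenizer(texts, truncation=False, padding=False)["input_ids"]
print([len(t) for t in texts])          # character counts
print([len(ids) for ids in token_ids])  # token counts, often very different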

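The commit message also mentions making synchronization device agnostic; that change is not visible in the diffs shown here. A minimal sketch of what such a helper could look like (entirely an assumption about the implementation):

# Hypothetical helper: dispatch synchronization on the device type instead
# of calling torch.cuda.synchronize unconditionally.
import torch

def synchronize(device):
    device = torch.device(device)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()
    # CPU execution is synchronous; nothing to do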