@@ -11,80 +11,9 @@
 from scipy.sparse import csr_array
 
 try:
-    from ..utils import Document
+    from ..utils import Document, SimilarLengthsBatchifyer
 except:
-    from utils import Document
-
-
-class SimilarLengthsBatchifyer:
-    """
-    Generator class to split samples into batches. Groups sample sequences
-    of equal/similar length together to minimize the need for padding within a batch.
-    """
-    def __init__(self, batch_size, inputs, max_padding_len=10):
-        # Remember number of samples
-        self.num_samples = len(inputs)
-
-        self.unique_lengths = set()
-        self.length_to_sample_indices = {}
-
-        for i in range(0, len(inputs)):
-            len_input = len(inputs[i])
-
-            self.unique_lengths.add(len_input)
-
-            # For each length, keep track of the indices of the samples that have this length
-            # E.g.: self.length_to_sample_indices = { 3: [3,5,11], 4: [1,2], ...}
-            if len_input in self.length_to_sample_indices:
-                self.length_to_sample_indices[len_input].append(i)
-            else:
-                self.length_to_sample_indices[len_input] = [i]
-
-        # Use a dynamic batch size to speed up inference at a constant VRAM usage
-        self.unique_lengths = sorted(list(self.unique_lengths))
-        max_chars_per_batch = self.unique_lengths[-1] * batch_size
-        self.length_to_batch_size = {length: int(max_chars_per_batch / (length * batch_size)) * batch_size for length in self.unique_lengths}
-
-        # Merge samples of similar lengths in those cases where the amount of samples
-        # of a particular length is < dynamic batch size
-        accum_len_diff = 0
-        for i in range(1, len(self.unique_lengths)):
-            if accum_len_diff >= max_padding_len:
-                accum_len_diff = 0
-                continue
-            curr_len = self.unique_lengths[i]
-            prev_len = self.unique_lengths[i-1]
-            len_diff = curr_len - prev_len
-            if (len_diff <= max_padding_len and
-                    (len(self.length_to_sample_indices[curr_len]) < self.length_to_batch_size[curr_len]
-                     or len(self.length_to_sample_indices[prev_len]) < self.length_to_batch_size[prev_len])):
-                self.length_to_sample_indices[curr_len].extend(self.length_to_sample_indices[prev_len])
-                self.length_to_sample_indices[prev_len] = []
-                accum_len_diff += len_diff
-            else:
-                accum_len_diff = 0
-
-    def __len__(self):
-        return self.num_samples
-
-    def __iter__(self):
-        # Iterate over all possible sentence lengths
-        for length in self.unique_lengths:
-
-            # Get indices of all samples for the current length
-            # for example, all indices of samples with a length of 7
-            sequence_indices = self.length_to_sample_indices[length]
-            if len(sequence_indices) == 0:
-                continue
-
-            dyn_batch_size = self.length_to_batch_size[length]
-
-            # Compute the number of batches
-            num_batches = np.ceil(len(sequence_indices) / dyn_batch_size)
-
-            # Loop over all possible batches
-            for batch_indices in np.array_split(sequence_indices, num_batches):
-                yield batch_indices
+    from utils import Document, SimilarLengthsBatchifyer
 
 
 def neg_dot_dist(x, y):
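For context, `SimilarLengthsBatchifyer` is unchanged here, just moved into `utils`: it groups sample indices by equal or near-equal sequence length and yields one batch of indices at a time, so batches need little or no padding. A minimal usage sketch, assuming the class is importable from `utils` and using toy inputs:

```python
from utils import SimilarLengthsBatchifyer  # assumed import path

# Toy token-ID sequences of varying length
sequences = [[1, 2], [3, 4], [5, 6, 7], [8, 9, 10], [11]]

# Yields batches of *indices* into `sequences`; samples within a batch have
# equal (or nearly equal, within max_padding_len) length.
for index_batch in SimilarLengthsBatchifyer(batch_size=2, inputs=sequences):
    print(index_batch, [sequences[i] for i in index_batch])
```

Because the per-length batch size is scaled by `max_chars_per_batch / length`, longer sequences get proportionally smaller batches, which keeps the token count per batch, and hence VRAM usage, roughly constant.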
@@ -112,7 +41,9 @@ def __init__(self, splade_doc_tokenizer, splade_doc_model, splade_query_tokenize
     def compute_document_vectors(self, texts: List[str], batch_size: int) -> Tuple[List[List[int]], List[List[float]]]:
         indices = []
         values = []
-        batchifyer = SimilarLengthsBatchifyer(batch_size, texts)
+        tokenized_texts = self.splade_doc_tokenizer(texts, truncation=False, padding=False,
+                                                    return_tensors="np")["input_ids"]
+        batchifyer = SimilarLengthsBatchifyer(batch_size, tokenized_texts)
         texts = np.array(texts)
         batch_indices = []
         for index_batch in batchifyer:
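The updated `compute_document_vectors` batches by token count rather than character count: documents are tokenized once without padding or truncation, and the ragged `input_ids` lengths drive the batchifyer. A rough sketch of that flow, with a placeholder tokenizer checkpoint and the same assumed `utils` import:

```python
from transformers import AutoTokenizer
from utils import SimilarLengthsBatchifyer  # assumed import path

# Placeholder checkpoint; in the PR this is whatever self.splade_doc_tokenizer wraps.
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")

texts = ["a short document",
         "a considerably longer document about sparse lexical retrieval"]

# No padding/truncation: every sample keeps its true token length,
# so batching happens over token counts, not character counts.
token_ids = tokenizer(texts, truncation=False, padding=False)["input_ids"]

for index_batch in SimilarLengthsBatchifyer(batch_size=32, inputs=token_ids):
    batch_texts = [texts[i] for i in index_batch]
    # ...tokenize batch_texts with padding and run the SPLADE document model...
```

Token length is what the model actually pads on, so grouping by it should yield tighter batches than grouping by raw string length.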