Skip to content

Commit 8e2c6aa

Browse files
authored
Convert some complex PriorityQueue implementations to use comparators (#14817)
These require more refactoring than previous PRs. Remaining PriorityQueue subclasses are tricky, are public API, or involve multiple further subclasses.
1 parent 1f4da95 commit 8e2c6aa

File tree

17 files changed

+168
-259
lines changed

17 files changed

+168
-259
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ Other
8585
* GITHUB#14613: Rewrite APIJAR extractor to use Java 24 classfile API and kill ASM dependency also for build system. (Uwe Schindler)
8686

8787
* GITHUB#14705: Use Comparators for some PriorityQueue implementations. (Simon Cooper)
88+
* GITHUB#14761: Use more Comparators for PriorityQueue implementations. (Simon Cooper)
89+
* GITHUB#14817: Refactor some complex uses of PriorityQueue to use Comparators. (Simon Cooper)
8890

8991
======================= Lucene 10.3.0 =======================
9092

lucene/core/src/java/org/apache/lucene/index/MultiTermsEnum.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ protected boolean lessThan(TermsEnumWithSlice termsA, TermsEnumWithSlice termsB)
385385
}
386386

387387
/**
388-
* Add the {@link #top()} slice as well as all slices that are positionned on the same term to
388+
* Add the {@link #top()} slice as well as all slices that are positioned on the same term to
389389
* {@code tops} and return how many of them there are.
390390
*/
391391
int fillTop(TermsEnumWithSlice[] tops) {
@@ -402,7 +402,7 @@ int fillTop(TermsEnumWithSlice[] tops) {
402402
final int index = stack[--stackLen];
403403
final int leftChild = index << 1;
404404
for (int child = leftChild, end = Math.min(size, leftChild + 1); child <= end; ++child) {
405-
TermsEnumWithSlice te = get(child);
405+
TermsEnumWithSlice te = (TermsEnumWithSlice) getHeapArray()[child];
406406
if (te.compareTermTo(tops[0]) == 0) {
407407
tops[numTop++] = te;
408408
stack[stackLen++] = child;
@@ -411,10 +411,6 @@ int fillTop(TermsEnumWithSlice[] tops) {
411411
}
412412
return numTop;
413413
}
414-
415-
private TermsEnumWithSlice get(int i) {
416-
return (TermsEnumWithSlice) getHeapArray()[i];
417-
}
418414
}
419415

420416
@Override

lucene/core/src/java/org/apache/lucene/search/DisjunctionMatchesIterator.java

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -191,23 +191,20 @@ static MatchesIterator fromSubIterators(List<MatchesIterator> mis) throws IOExce
191191

192192
private DisjunctionMatchesIterator(List<MatchesIterator> matches) throws IOException {
193193
queue =
194-
new PriorityQueue<MatchesIterator>(matches.size()) {
195-
@Override
196-
protected boolean lessThan(MatchesIterator a, MatchesIterator b) {
197-
if (a.startPosition() == -1 && b.startPosition() == -1) {
198-
try {
199-
return a.startOffset() < b.startOffset()
200-
|| (a.startOffset() == b.startOffset() && a.endOffset() < b.endOffset())
201-
|| (a.startOffset() == b.startOffset() && a.endOffset() == b.endOffset());
202-
} catch (IOException e) {
203-
throw new IllegalArgumentException("Failed to retrieve term offset", e);
194+
PriorityQueue.usingLessThan(
195+
matches.size(),
196+
(a, b) -> {
197+
if (a.startPosition() == -1 && b.startPosition() == -1) {
198+
try {
199+
return a.startOffset() < b.startOffset()
200+
|| (a.startOffset() == b.startOffset() && a.endOffset() <= b.endOffset());
201+
} catch (IOException e) {
202+
throw new IllegalArgumentException("Failed to retrieve term offset", e);
203+
}
204204
}
205-
}
206-
return a.startPosition() < b.startPosition()
207-
|| (a.startPosition() == b.startPosition() && a.endPosition() < b.endPosition())
208-
|| (a.startPosition() == b.startPosition() && a.endPosition() == b.endPosition());
209-
}
210-
};
205+
return a.startPosition() < b.startPosition()
206+
|| (a.startPosition() == b.startPosition() && a.endPosition() <= b.endPosition());
207+
});
211208
for (MatchesIterator mi : matches) {
212209
if (mi.next()) {
213210
queue.add(mi);

lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Comparator;
2222
import java.util.Iterator;
2323
import java.util.NoSuchElementException;
24+
import java.util.function.IntFunction;
2425
import java.util.function.Supplier;
2526

2627
/**
@@ -320,6 +321,30 @@ public final boolean remove(T element) {
320321
return false;
321322
}
322323

324+
/**
325+
* Moves the contents of this queue into a new array created by {@code newArray}, lowest items
326+
* first. Elements are removed with {@link #pop()} as they are copied.
327+
*/
328+
public T[] drainToArrayLowestFirst(IntFunction<T[]> newArray) {
329+
T[] array = newArray.apply(size);
330+
for (int i = 0; i < array.length; i++) {
331+
array[i] = pop();
332+
}
333+
return array;
334+
}
335+
336+
/**
337+
* Moves the contents of this queue into a new array created by {@code newArray}, highest items
338+
* first. Elements are removed with {@link #pop()} as they are copied.
339+
*/
340+
public T[] drainToArrayHighestFirst(IntFunction<T[]> newArray) {
341+
T[] array = newArray.apply(size);
342+
for (int i = array.length - 1; i >= 0; i--) {
343+
array[i] = pop();
344+
}
345+
return array;
346+
}
347+
323348
private boolean upHeap(int origPos) {
324349
int i = origPos;
325350
T node = heap[i]; // save bottom node

lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.lucene.search.grouping;
1818

1919
import java.io.IOException;
20+
import java.util.Comparator;
2021
import org.apache.lucene.index.IndexWriter;
2122
import org.apache.lucene.index.LeafReaderContext;
2223
import org.apache.lucene.search.DocIdSetIterator;
@@ -92,7 +93,7 @@ public class BlockGroupingCollector extends SimpleCollector {
9293
private int groupEndDocID;
9394
private DocIdSetIterator lastDocPerGroupBits;
9495
private Scorable scorer;
95-
private final GroupQueue groupQueue;
96+
private final PriorityQueue<OneGroup> groupQueue;
9697
private boolean groupCompetes;
9798

9899
private static final class OneGroup {
@@ -105,34 +106,31 @@ private static final class OneGroup {
105106
int comparatorSlot;
106107
}
107108

108-
// Sorts by groupSort. Not static -- uses comparators, reversed
109-
private final class GroupQueue extends PriorityQueue<OneGroup> {
110-
111-
public GroupQueue(int size) {
112-
super(size);
113-
}
114-
115-
@Override
116-
protected boolean lessThan(final OneGroup group1, final OneGroup group2) {
117-
118-
// System.out.println(" ltcheck");
119-
assert group1 != group2;
120-
assert group1.comparatorSlot != group2.comparatorSlot;
121-
122-
final int numComparators = comparators.length;
123-
for (int compIDX = 0; compIDX < numComparators; compIDX++) {
124-
final int c =
125-
reversed[compIDX]
126-
* comparators[compIDX].compare(group1.comparatorSlot, group2.comparatorSlot);
127-
if (c != 0) {
128-
// Short circuit
129-
return c > 0;
130-
}
131-
}
132-
133-
// Break ties by docID; lower docID is always sorted first
134-
return group1.topGroupDoc > group2.topGroupDoc;
135-
}
109+
private PriorityQueue<OneGroup> createGroupQueue(int size) {
110+
// Sorts by groupSort
111+
return PriorityQueue.usingComparator(
112+
size,
113+
((Comparator<OneGroup>)
114+
(group1, group2) -> {
115+
assert group1 != group2;
116+
assert group1.comparatorSlot != group2.comparatorSlot;
117+
118+
final int numComparators = comparators.length;
119+
for (int compIDX = 0; compIDX < numComparators; compIDX++) {
120+
final int c =
121+
reversed[compIDX]
122+
* comparators[compIDX].compare(
123+
group1.comparatorSlot, group2.comparatorSlot);
124+
if (c != 0) {
125+
// Short circuit
126+
return c;
127+
}
128+
}
129+
return 0;
130+
})
131+
.thenComparingInt(
132+
g -> g.topGroupDoc) // Break ties by docID; lower docID is always sorted first
133+
.reversed());
136134
}
137135

138136
// Called when we transition to another group; if the
@@ -221,7 +219,7 @@ public BlockGroupingCollector(
221219
throw new IllegalArgumentException("topNGroups must be >= 1 (got " + topNGroups + ")");
222220
}
223221

224-
groupQueue = new GroupQueue(topNGroups);
222+
groupQueue = createGroupQueue(topNGroups);
225223
pendingSubDocs = new int[10];
226224
if (needsScores) {
227225
pendingSubScores = new float[10];

lucene/highlighter/src/java/org/apache/lucene/search/highlight/Highlighter.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,7 @@ public final TextFragment[] getBestTextFragments(
292292
}
293293

294294
// return the most relevant fragments
295-
TextFragment[] frag = new TextFragment[fragQueue.size()];
296-
for (int i = frag.length - 1; i >= 0; i--) {
297-
frag[i] = fragQueue.pop();
298-
}
295+
TextFragment[] frag = fragQueue.drainToArrayHighestFirst(TextFragment[]::new);
299296

300297
// merge any contiguous fragments to improve readability
301298
if (mergeContiguousFragments) {

lucene/highlighter/src/java/org/apache/lucene/search/matchhighlight/PassageSelector.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,7 @@ public List<Passage> pickBest(
168168
// Collect from the priority queue (reverse the order so that highest-scoring are first).
169169
Passage[] passages;
170170
if (pq.size() > 0) {
171-
passages = new Passage[pq.size()];
172-
for (int i = pq.size(); --i >= 0; ) {
173-
passages[i] = pq.pop();
174-
}
171+
passages = pq.drainToArrayHighestFirst(Passage[]::new);
175172
} else {
176173
// Handle the default, no highlighting markers case.
177174
passages = pickDefaultPassage(value, maxPassageWindow, maxPassages, permittedPassageRanges);

lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,7 @@ protected TopDocs exactSearch(
148148
queue.pop();
149149
}
150150

151-
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
152-
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
153-
topScoreDocs[i] = queue.pop();
154-
}
151+
ScoreDoc[] topScoreDocs = queue.drainToArrayHighestFirst(ScoreDoc[]::new);
155152

156153
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
157154
return new TopDocs(totalHits, topScoreDocs);

lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,7 @@ protected TopDocs exactSearch(
147147
queue.pop();
148148
}
149149

150-
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
151-
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
152-
topScoreDocs[i] = queue.pop();
153-
}
150+
ScoreDoc[] topScoreDocs = queue.drainToArrayHighestFirst(ScoreDoc[]::new);
154151

155152
TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
156153
return new TopDocs(totalHits, topScoreDocs);

lucene/misc/src/java/org/apache/lucene/misc/HighFreqTerms.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,7 @@ public static TermStats[] getHighFreqTerms(
124124
}
125125
}
126126

127-
TermStats[] result = new TermStats[tiq.size()];
128-
// we want highest first so we read the queue and populate the array
129-
// starting at the end and work backwards
130-
int count = tiq.size() - 1;
131-
while (tiq.size() != 0) {
132-
result[count] = tiq.pop();
133-
count--;
134-
}
135-
return result;
127+
return tiq.drainToArrayHighestFirst(TermStats[]::new);
136128
}
137129

138130
/** Compares terms by docTermFreq */

0 commit comments

Comments
 (0)