diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6a7f6b897141..17c927fd3cd0 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -81,6 +81,8 @@ Other * GITHUB#14613: Rewrite APIJAR extractor to use Java 24 classfile API and kill ASM dependency also for build system. (Uwe Schindler) * GITHUB#14705: Use Comparators for some PriorityQueue implementations. (Simon Cooper) +* GITHUB#14761: Use more Comparators for PriorityQueue implementations. (Simon Cooper) +* GITHUB#14817: Refactor some complex uses of PriorityQueue to use Comparators. (Simon Cooper) ======================= Lucene 10.3.0 ======================= diff --git a/lucene/core/src/java/org/apache/lucene/index/MultiTermsEnum.java b/lucene/core/src/java/org/apache/lucene/index/MultiTermsEnum.java index 144410f4cd9b..420c58474813 100644 --- a/lucene/core/src/java/org/apache/lucene/index/MultiTermsEnum.java +++ b/lucene/core/src/java/org/apache/lucene/index/MultiTermsEnum.java @@ -385,7 +385,7 @@ protected boolean lessThan(TermsEnumWithSlice termsA, TermsEnumWithSlice termsB) } /** - * Add the {@link #top()} slice as well as all slices that are positionned on the same term to + * Add the {@link #top()} slice as well as all slices that are positioned on the same term to * {@code tops} and return how many of them there are. */ int fillTop(TermsEnumWithSlice[] tops) { @@ -402,7 +402,7 @@ int fillTop(TermsEnumWithSlice[] tops) { final int index = stack[--stackLen]; final int leftChild = index << 1; for (int child = leftChild, end = Math.min(size, leftChild + 1); child <= end; ++child) { - TermsEnumWithSlice te = get(child); + TermsEnumWithSlice te = (TermsEnumWithSlice) getHeapArray()[child]; if (te.compareTermTo(tops[0]) == 0) { tops[numTop++] = te; stack[stackLen++] = child; @@ -411,10 +411,6 @@ int fillTop(TermsEnumWithSlice[] tops) { } return numTop; } - - private TermsEnumWithSlice get(int i) { - return (TermsEnumWithSlice) getHeapArray()[i]; - } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMatchesIterator.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMatchesIterator.java index 13852b1b9eed..3d1c35838f38 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMatchesIterator.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMatchesIterator.java @@ -191,23 +191,20 @@ static MatchesIterator fromSubIterators(List mis) throws IOExce private DisjunctionMatchesIterator(List matches) throws IOException { queue = - new PriorityQueue(matches.size()) { - @Override - protected boolean lessThan(MatchesIterator a, MatchesIterator b) { - if (a.startPosition() == -1 && b.startPosition() == -1) { - try { - return a.startOffset() < b.startOffset() - || (a.startOffset() == b.startOffset() && a.endOffset() < b.endOffset()) - || (a.startOffset() == b.startOffset() && a.endOffset() == b.endOffset()); - } catch (IOException e) { - throw new IllegalArgumentException("Failed to retrieve term offset", e); + PriorityQueue.usingLessThan( + matches.size(), + (a, b) -> { + if (a.startPosition() == -1 && b.startPosition() == -1) { + try { + return a.startOffset() < b.startOffset() + || (a.startOffset() == b.startOffset() && a.endOffset() <= b.endOffset()); + } catch (IOException e) { + throw new IllegalArgumentException("Failed to retrieve term offset", e); + } } - } - return a.startPosition() < b.startPosition() - || (a.startPosition() == b.startPosition() && a.endPosition() < b.endPosition()) - || (a.startPosition() == b.startPosition() && a.endPosition() == b.endPosition()); - } - }; + return a.startPosition() < b.startPosition() + || (a.startPosition() == b.startPosition() && a.endPosition() <= b.endPosition()); + }); for (MatchesIterator mi : matches) { if (mi.next()) { queue.add(mi); diff --git a/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java b/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java index efb0ebeae9db..8b31baed9e91 100644 --- a/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java @@ -21,6 +21,7 @@ import java.util.Comparator; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.function.IntFunction; import java.util.function.Supplier; /** @@ -320,6 +321,30 @@ public final boolean remove(T element) { return false; } + /** + * Moves the contents of this queue into a new array created by {@code newArray}, lowest items + * first + */ + public T[] drainToArrayLowestFirst(IntFunction newArray) { + T[] array = newArray.apply(size); + for (int i = 0; i < array.length; i++) { + array[i] = pop(); + } + return array; + } + + /** + * Moves the contents of this queue into a new array created by {@code newArray}, highest items + * first + */ + public T[] drainToArrayHighestFirst(IntFunction newArray) { + T[] array = newArray.apply(size); + for (int i = array.length - 1; i >= 0; i--) { + array[i] = pop(); + } + return array; + } + private boolean upHeap(int origPos) { int i = origPos; T node = heap[i]; // save bottom node diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java index 779b62291a69..1f6a473f0e65 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java @@ -17,6 +17,7 @@ package org.apache.lucene.search.grouping; import java.io.IOException; +import java.util.Comparator; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; @@ -92,7 +93,7 @@ public class BlockGroupingCollector extends SimpleCollector { private int groupEndDocID; private DocIdSetIterator lastDocPerGroupBits; private Scorable scorer; - private final GroupQueue groupQueue; + private final PriorityQueue groupQueue; private boolean groupCompetes; private static final class OneGroup { @@ -105,34 +106,31 @@ private static final class OneGroup { int comparatorSlot; } - // Sorts by groupSort. Not static -- uses comparators, reversed - private final class GroupQueue extends PriorityQueue { - - public GroupQueue(int size) { - super(size); - } - - @Override - protected boolean lessThan(final OneGroup group1, final OneGroup group2) { - - // System.out.println(" ltcheck"); - assert group1 != group2; - assert group1.comparatorSlot != group2.comparatorSlot; - - final int numComparators = comparators.length; - for (int compIDX = 0; compIDX < numComparators; compIDX++) { - final int c = - reversed[compIDX] - * comparators[compIDX].compare(group1.comparatorSlot, group2.comparatorSlot); - if (c != 0) { - // Short circuit - return c > 0; - } - } - - // Break ties by docID; lower docID is always sorted first - return group1.topGroupDoc > group2.topGroupDoc; - } + private PriorityQueue createGroupQueue(int size) { + // Sorts by groupSort + return PriorityQueue.usingComparator( + size, + ((Comparator) + (group1, group2) -> { + assert group1 != group2; + assert group1.comparatorSlot != group2.comparatorSlot; + + final int numComparators = comparators.length; + for (int compIDX = 0; compIDX < numComparators; compIDX++) { + final int c = + reversed[compIDX] + * comparators[compIDX].compare( + group1.comparatorSlot, group2.comparatorSlot); + if (c != 0) { + // Short circuit + return c; + } + } + return 0; + }) + .thenComparingInt( + g -> g.topGroupDoc) // Break ties by docID; lower docID is always sorted first + .reversed()); } // Called when we transition to another group; if the @@ -221,7 +219,7 @@ public BlockGroupingCollector( throw new IllegalArgumentException("topNGroups must be >= 1 (got " + topNGroups + ")"); } - groupQueue = new GroupQueue(topNGroups); + groupQueue = createGroupQueue(topNGroups); pendingSubDocs = new int[10]; if (needsScores) { pendingSubScores = new float[10]; diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/Highlighter.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/Highlighter.java index 1a17b3267db3..19aa18425830 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/Highlighter.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/Highlighter.java @@ -292,10 +292,7 @@ public final TextFragment[] getBestTextFragments( } // return the most relevant fragments - TextFragment[] frag = new TextFragment[fragQueue.size()]; - for (int i = frag.length - 1; i >= 0; i--) { - frag[i] = fragQueue.pop(); - } + TextFragment[] frag = fragQueue.drainToArrayHighestFirst(TextFragment[]::new); // merge any contiguous fragments to improve readability if (mergeContiguousFragments) { diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/matchhighlight/PassageSelector.java b/lucene/highlighter/src/java/org/apache/lucene/search/matchhighlight/PassageSelector.java index b3d929e3e246..d70a6c8fc113 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/matchhighlight/PassageSelector.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/matchhighlight/PassageSelector.java @@ -168,10 +168,7 @@ public List pickBest( // Collect from the priority queue (reverse the order so that highest-scoring are first). Passage[] passages; if (pq.size() > 0) { - passages = new Passage[pq.size()]; - for (int i = pq.size(); --i >= 0; ) { - passages[i] = pq.pop(); - } + passages = pq.drainToArrayHighestFirst(Passage[]::new); } else { // Handle the default, no highlighting markers case. passages = pickDefaultPassage(value, maxPassageWindow, maxPassages, permittedPassageRanges); diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index d76b2bed75af..f3e1c518bdd7 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -148,10 +148,7 @@ protected TopDocs exactSearch( queue.pop(); } - ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()]; - for (int i = topScoreDocs.length - 1; i >= 0; i--) { - topScoreDocs[i] = queue.pop(); - } + ScoreDoc[] topScoreDocs = queue.drainToArrayHighestFirst(ScoreDoc[]::new); TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation); return new TopDocs(totalHits, topScoreDocs); diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 3cffa44772fe..d98fc62fdbe9 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -147,10 +147,7 @@ protected TopDocs exactSearch( queue.pop(); } - ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()]; - for (int i = topScoreDocs.length - 1; i >= 0; i--) { - topScoreDocs[i] = queue.pop(); - } + ScoreDoc[] topScoreDocs = queue.drainToArrayHighestFirst(ScoreDoc[]::new); TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation); return new TopDocs(totalHits, topScoreDocs); diff --git a/lucene/misc/src/java/org/apache/lucene/misc/HighFreqTerms.java b/lucene/misc/src/java/org/apache/lucene/misc/HighFreqTerms.java index 0da7f5d9510b..a3e862299a4a 100644 --- a/lucene/misc/src/java/org/apache/lucene/misc/HighFreqTerms.java +++ b/lucene/misc/src/java/org/apache/lucene/misc/HighFreqTerms.java @@ -124,15 +124,7 @@ public static TermStats[] getHighFreqTerms( } } - TermStats[] result = new TermStats[tiq.size()]; - // we want highest first so we read the queue and populate the array - // starting at the end and work backwards - int count = tiq.size() - 1; - while (tiq.size() != 0) { - result[count] = tiq.pop(); - count--; - } - return result; + return tiq.drainToArrayHighestFirst(TermStats[]::new); } /** Compares terms by docTermFreq */ diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/NearSpansUnordered.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/NearSpansUnordered.java index 5c15b7520a4b..ca7411c1e6b3 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/NearSpansUnordered.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/NearSpansUnordered.java @@ -17,6 +17,7 @@ package org.apache.lucene.queries.spans; import java.io.IOException; +import java.util.Comparator; import java.util.List; import org.apache.lucene.util.PriorityQueue; @@ -28,96 +29,72 @@ public class NearSpansUnordered extends ConjunctionSpans { private final int allowedSlop; - private SpanTotalLengthEndPositionWindow spanWindow; + private final PriorityQueue spanWindow; + + private int totalSpanLength; + private int maxEndPosition; public NearSpansUnordered(int allowedSlop, List subSpans) throws IOException { super(subSpans); this.allowedSlop = allowedSlop; - this.spanWindow = new SpanTotalLengthEndPositionWindow(); + this.spanWindow = + PriorityQueue.usingComparator( + super.subSpans.length, + Comparator.comparingInt(Spans::startPosition).thenComparingInt(Spans::endPosition)); } - /** Maintain totalSpanLength and maxEndPosition */ - private class SpanTotalLengthEndPositionWindow extends PriorityQueue { - int totalSpanLength; - int maxEndPosition; - - public SpanTotalLengthEndPositionWindow() { - super(subSpans.length); - } - - @Override - protected final boolean lessThan(Spans spans1, Spans spans2) { - return positionsOrdered(spans1, spans2); - } - - void startDocument() throws IOException { - clear(); - totalSpanLength = 0; - maxEndPosition = -1; - for (Spans spans : subSpans) { - assert spans.startPosition() == -1; - spans.nextStartPosition(); - assert spans.startPosition() != NO_MORE_POSITIONS; - add(spans); - if (spans.endPosition() > maxEndPosition) { - maxEndPosition = spans.endPosition(); - } - int spanLength = spans.endPosition() - spans.startPosition(); - assert spanLength >= 0; - totalSpanLength += spanLength; - } - } - - boolean nextPosition() throws IOException { - Spans topSpans = top(); - assert topSpans.startPosition() != NO_MORE_POSITIONS; - int spanLength = topSpans.endPosition() - topSpans.startPosition(); - int nextStartPos = topSpans.nextStartPosition(); - if (nextStartPos == NO_MORE_POSITIONS) { - return false; + private void startDocument() throws IOException { + spanWindow.clear(); + totalSpanLength = 0; + maxEndPosition = -1; + for (Spans spans : subSpans) { + assert spans.startPosition() == -1; + spans.nextStartPosition(); + assert spans.startPosition() != NO_MORE_POSITIONS; + spanWindow.add(spans); + if (spans.endPosition() > maxEndPosition) { + maxEndPosition = spans.endPosition(); } - totalSpanLength -= spanLength; - spanLength = topSpans.endPosition() - topSpans.startPosition(); + int spanLength = spans.endPosition() - spans.startPosition(); + assert spanLength >= 0; totalSpanLength += spanLength; - if (topSpans.endPosition() > maxEndPosition) { - maxEndPosition = topSpans.endPosition(); - } - updateTop(); - return true; } + } - boolean atMatch() { - boolean res = (maxEndPosition - top().startPosition() - totalSpanLength) <= allowedSlop; - return res; + private boolean nextPosition() throws IOException { + Spans topSpans = spanWindow.top(); + assert topSpans.startPosition() != NO_MORE_POSITIONS; + int spanLength = topSpans.endPosition() - topSpans.startPosition(); + int nextStartPos = topSpans.nextStartPosition(); + if (nextStartPos == NO_MORE_POSITIONS) { + return false; + } + totalSpanLength -= spanLength; + spanLength = topSpans.endPosition() - topSpans.startPosition(); + totalSpanLength += spanLength; + if (topSpans.endPosition() > maxEndPosition) { + maxEndPosition = topSpans.endPosition(); } + spanWindow.updateTop(); + return true; } - /** - * Check whether two Spans in the same document are ordered with possible overlap. - * - * @return true iff spans1 starts before spans2 or the spans start at the same position, and - * spans1 ends before spans2. - */ - static boolean positionsOrdered(Spans spans1, Spans spans2) { - assert spans1.docID() == spans2.docID() - : "doc1 " + spans1.docID() + " != doc2 " + spans2.docID(); - int start1 = spans1.startPosition(); - int start2 = spans2.startPosition(); - return (start1 == start2) ? (spans1.endPosition() < spans2.endPosition()) : (start1 < start2); + private boolean atMatch() { + return (maxEndPosition - spanWindow.top().startPosition() - totalSpanLength) <= allowedSlop; } @Override boolean twoPhaseCurrentDocMatches() throws IOException { // at doc with all subSpans - spanWindow.startDocument(); + startDocument(); while (true) { - if (spanWindow.atMatch()) { + if (atMatch()) { atFirstInCurrentDoc = true; oneExhaustedInCurrentDoc = false; return true; } - if (!spanWindow.nextPosition()) { + if (!nextPosition()) { return false; } } @@ -132,11 +109,11 @@ public int nextStartPosition() throws IOException { assert spanWindow.top().startPosition() != -1; assert spanWindow.top().startPosition() != NO_MORE_POSITIONS; while (true) { - if (!spanWindow.nextPosition()) { + if (!nextPosition()) { oneExhaustedInCurrentDoc = true; return NO_MORE_POSITIONS; } - if (spanWindow.atMatch()) { + if (atMatch()) { return spanWindow.top().startPosition(); } } @@ -152,14 +129,12 @@ public int startPosition() { @Override public int endPosition() { - return atFirstInCurrentDoc - ? -1 - : oneExhaustedInCurrentDoc ? NO_MORE_POSITIONS : spanWindow.maxEndPosition; + return atFirstInCurrentDoc ? -1 : oneExhaustedInCurrentDoc ? NO_MORE_POSITIONS : maxEndPosition; } @Override public int width() { - return spanWindow.maxEndPosition - spanWindow.top().startPosition(); + return maxEndPosition - spanWindow.top().startPosition(); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/SortedNumericDocValuesMultiRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/SortedNumericDocValuesMultiRangeQuery.java index d6e679200af6..46aeaac6a5ed 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/SortedNumericDocValuesMultiRangeQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/SortedNumericDocValuesMultiRangeQuery.java @@ -17,6 +17,7 @@ package org.apache.lucene.sandbox.search; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; import java.util.Iterator; @@ -38,7 +39,6 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; -import org.apache.lucene.util.PriorityQueue; /** * A union multiple ranges over SortedNumericDocValuesField @@ -91,30 +91,32 @@ private static NavigableSet resolveOverlaps( Comparator.comparing(r -> r.lower) // .thenComparing(r -> r.upper)// have to ignore upper boundary for .floor() lookups ); - PriorityQueue heap = - PriorityQueue.usingComparator( - clauses.size() * 2, - Comparator.comparingLong(Edge::getValue) - .thenComparing(e -> e.point, Comparator.reverseOrder())); // points first + List clauseEdges = new ArrayList<>(clauses.size() * 2); + for (DocValuesMultiRangeQuery.LongRange r : clauses) { long cmp = r.lower - r.upper; if (cmp == 0) { - heap.add(Edge.createPoint(r)); + clauseEdges.add(Edge.createPoint(r)); } else { if (cmp < 0) { - heap.add(new Edge(r, false)); - heap.add(new Edge(r, true)); + clauseEdges.add(new Edge(r, false)); + clauseEdges.add(new Edge(r, true)); } // else drop reverse ranges } } - int totalEdges = heap.size(); + + // sort by edge value, then points first + clauseEdges.sort( + Comparator.comparingLong(Edge::getValue) + .thenComparing(e -> e.point, Comparator.reverseOrder())); + int depth = 0; Edge started = null; - for (int i = 0; i < totalEdges; i++) { - Edge smallest = heap.pop(); + for (int i = 0; i < clauseEdges.size(); i++) { + Edge smallest = clauseEdges.get(i); if (depth == 0 && smallest.point) { - if (i < totalEdges - 1) { // the point sits on the edge of the range - if (smallest.getValue() == heap.top().getValue()) { + if (i < clauseEdges.size() - 1) { // the point sits on the edge of the range + if (smallest.getValue() == clauseEdges.get(i + 1).getValue()) { continue; } } diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java index 2fc7108c2a7c..05915d5defb1 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java @@ -31,7 +31,6 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.PriorityQueue; /** * Simple Lookup interface for {@link CharSequence} suggestions. @@ -135,34 +134,6 @@ public int compare(CharSequence o1, CharSequence o2) { } } - /** A {@link PriorityQueue} collecting a fixed size of high priority {@link LookupResult} */ - public static final class LookupPriorityQueue extends PriorityQueue { - // TODO: should we move this out of the interface into a utility class? - /** Creates a new priority queue of the specified size. */ - public LookupPriorityQueue(int size) { - super(size); - } - - @Override - protected boolean lessThan(LookupResult a, LookupResult b) { - return a.value < b.value; - } - - /** - * Returns the top N results in descending order. - * - * @return the top N results in descending order. - */ - public LookupResult[] getResults() { - int size = size(); - LookupResult[] res = new LookupResult[size]; - for (int i = size - 1; i >= 0; i--) { - res[i] = pop(); - } - return res; - } - } - /** Sole constructor. (For invocation by subclass constructors, typically implicit.) */ public Lookup() {} diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestScoreDocPriorityQueue.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestScoreDocPriorityQueue.java deleted file mode 100644 index fee4d35c4d22..000000000000 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestScoreDocPriorityQueue.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.search.suggest.document; - -import org.apache.lucene.search.suggest.Lookup; -import org.apache.lucene.search.suggest.document.TopSuggestDocs.SuggestScoreDoc; -import org.apache.lucene.util.PriorityQueue; - -/** - * Bounded priority queue for {@link SuggestScoreDoc}s. Priority is based on {@link - * SuggestScoreDoc#score} and tie is broken by {@link SuggestScoreDoc#doc} - */ -final class SuggestScoreDocPriorityQueue extends PriorityQueue { - /** Creates a new priority queue of the specified size. */ - public SuggestScoreDocPriorityQueue(int size) { - super(size); - } - - @Override - protected boolean lessThan(SuggestScoreDoc a, SuggestScoreDoc b) { - if (a.score == b.score) { - // tie break by completion key - int cmp = Lookup.CHARSEQUENCE_COMPARATOR.compare(a.key, b.key); - // prefer smaller doc id, in case of a tie - return cmp != 0 ? cmp > 0 : a.doc > b.doc; - } - return a.score < b.score; - } - - /** Returns the top N results in descending order. */ - public SuggestScoreDoc[] getResults() { - int size = size(); - SuggestScoreDoc[] res = new SuggestScoreDoc[size]; - for (int i = size - 1; i >= 0; i--) { - res[i] = pop(); - } - return res; - } -} diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocs.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocs.java index b19c9f3950b1..4648b977a483 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocs.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocs.java @@ -16,10 +16,12 @@ */ package org.apache.lucene.search.suggest.document; +import java.util.Comparator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.suggest.Lookup; +import org.apache.lucene.util.PriorityQueue; /** * {@link org.apache.lucene.search.TopDocs} wrapper with an additional CharSequence key per {@link @@ -33,6 +35,16 @@ public class TopSuggestDocs extends TopDocs { public static final TopSuggestDocs EMPTY = new TopSuggestDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new SuggestScoreDoc[0]); + static final Comparator SUGGEST_SCORE_DOC_COMPARATOR = + (a, b) -> { + // compare score, then key (reversed), then docID (reversed) + int cmp = Float.compare(a.score, b.score); + if (cmp != 0) return cmp; + cmp = Lookup.CHARSEQUENCE_COMPARATOR.compare(b.key, a.key); + if (cmp != 0) return cmp; + return Integer.compare(b.doc, a.doc); + }; + /** {@link org.apache.lucene.search.ScoreDoc} with an additional CharSequence key */ public static class SuggestScoreDoc extends ScoreDoc implements Comparable { @@ -102,7 +114,8 @@ public SuggestScoreDoc[] scoreLookupDocs() { *

NOTE: assumes every shardHit is already sorted by score */ public static TopSuggestDocs merge(int topN, TopSuggestDocs[] shardHits) { - SuggestScoreDocPriorityQueue priorityQueue = new SuggestScoreDocPriorityQueue(topN); + PriorityQueue priorityQueue = + PriorityQueue.usingComparator(topN, SUGGEST_SCORE_DOC_COMPARATOR); for (TopSuggestDocs shardHit : shardHits) { for (SuggestScoreDoc scoreDoc : shardHit.scoreLookupDocs()) { if (scoreDoc == priorityQueue.insertWithOverflow(scoreDoc)) { @@ -110,7 +123,7 @@ public static TopSuggestDocs merge(int topN, TopSuggestDocs[] shardHits) { } } } - SuggestScoreDoc[] topNResults = priorityQueue.getResults(); + SuggestScoreDoc[] topNResults = priorityQueue.drainToArrayHighestFirst(SuggestScoreDoc[]::new); if (topNResults.length > 0) { return new TopSuggestDocs( new TotalHits(topNResults.length, TotalHits.Relation.EQUAL_TO), topNResults); diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java index 7f8532d3ce58..7298caf4cb44 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java @@ -16,11 +16,11 @@ */ package org.apache.lucene.search.suggest.document; +import static org.apache.lucene.search.suggest.document.TopSuggestDocs.SUGGEST_SCORE_DOC_COMPARATOR; import static org.apache.lucene.search.suggest.document.TopSuggestDocs.SuggestScoreDoc; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import org.apache.lucene.analysis.CharArraySet; import org.apache.lucene.index.LeafReaderContext; @@ -29,6 +29,7 @@ import org.apache.lucene.search.SimpleCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.suggest.Lookup; +import org.apache.lucene.util.PriorityQueue; /** * {@link org.apache.lucene.search.Collector} that collects completion and score, along with @@ -49,7 +50,7 @@ */ public class TopSuggestDocsCollector extends SimpleCollector { - private final SuggestScoreDocPriorityQueue priorityQueue; + private final PriorityQueue priorityQueue; private final int num; /** @@ -77,10 +78,10 @@ public TopSuggestDocsCollector(int num, boolean skipDuplicates) { throw new IllegalArgumentException("'num' must be > 0"); } this.num = num; - this.priorityQueue = new SuggestScoreDocPriorityQueue(num); + this.priorityQueue = PriorityQueue.usingComparator(num, SUGGEST_SCORE_DOC_COMPARATOR); if (skipDuplicates) { seenSurfaceForms = new CharArraySet(num, false); - pendingResults = new ArrayList<>(); + pendingResults = new ArrayList<>(num); } else { seenSurfaceForms = null; pendingResults = null; @@ -105,8 +106,9 @@ protected void doSetNextReader(LeafReaderContext context) throws IOException { @Override public void finish() throws IOException { if (seenSurfaceForms != null) { - // NOTE: this also clears the priorityQueue: - Collections.addAll(pendingResults, priorityQueue.getResults()); + // doesn't need to be sorted now, it is sorted in the get() method + priorityQueue.iterator().forEachRemaining(pendingResults::add); + priorityQueue.clear(); // Deduplicate all hits: we already dedup'd efficiently within each segment by // truncating the FST top paths search, but across segments there may still be dups: @@ -165,17 +167,16 @@ public TopSuggestDocs get() throws IOException { List hits = new ArrayList<>(); for (SuggestScoreDoc hit : pendingResults) { - if (seenSurfaceForms.contains(hit.key) == false) { - seenSurfaceForms.add(hit.key); + if (seenSurfaceForms.add(hit.key)) { hits.add(hit); if (hits.size() == num) { break; } } } - suggestScoreDocs = hits.toArray(new SuggestScoreDoc[0]); + suggestScoreDocs = hits.toArray(SuggestScoreDoc[]::new); } else { - suggestScoreDocs = priorityQueue.getResults(); + suggestScoreDocs = priorityQueue.drainToArrayHighestFirst(SuggestScoreDoc[]::new); } if (suggestScoreDocs.length > 0) { diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/tst/TSTLookup.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/tst/TSTLookup.java index 547e745fec22..04a59768e582 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/tst/TSTLookup.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/tst/TSTLookup.java @@ -30,6 +30,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.CharsRefBuilder; +import org.apache.lucene.util.PriorityQueue; import org.apache.lucene.util.RamUsageEstimator; /** @@ -191,12 +192,13 @@ public List lookup( } int maxCnt = Math.min(num, list.size()); if (onlyMorePopular) { - LookupPriorityQueue queue = new LookupPriorityQueue(num); + PriorityQueue queue = + PriorityQueue.usingComparator(num, Comparator.comparingLong(lr -> lr.value)); for (TernaryTreeNode ttn : list) { queue.insertWithOverflow(new LookupResult(ttn.token, ((Number) ttn.val).longValue())); } - Collections.addAll(res, queue.getResults()); + Collections.addAll(res, queue.drainToArrayHighestFirst(LookupResult[]::new)); } else { for (int i = 0; i < maxCnt; i++) { TernaryTreeNode ttn = list.get(i);