-
Notifications
You must be signed in to change notification settings - Fork 25.3k
Implement support for weighted rrf #130658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8be20f0
a8f6487
e07c38d
33d3da4
5fb5568
3ba149c
d5749f6
7614936
a5d9e34
0640099
cec23c2
6da9e15
51b350e
4050a3a
ea664eb
98e72be
7de8c7a
c778dd2
c7b331d
f543cbe
75ab8d0
f5086c1
fafb50f
e535864
78f8641
02647b1
2010f3a
7433023
a2bf4de
eebf577
0388abd
74ed8db
6d7e8ff
f1e14ce
fd30387
3a82a28
77c14d3
45ca068
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 130658 | ||
summary: Implement support for weighted rrf | ||
area: Relevance | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
import org.elasticsearch.search.rank.RankBuilder; | ||
import org.elasticsearch.search.rank.RankDoc; | ||
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; | ||
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource; | ||
import org.elasticsearch.search.retriever.RetrieverBuilder; | ||
import org.elasticsearch.search.retriever.RetrieverParserContext; | ||
import org.elasticsearch.search.retriever.StandardRetrieverBuilder; | ||
|
@@ -37,7 +38,7 @@ | |
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT; | ||
|
||
/** | ||
* An rrf retriever is used to represent an rrf rank element, but | ||
|
@@ -48,46 +49,50 @@ | |
*/ | ||
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> { | ||
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support"); | ||
public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support"); | ||
|
||
public static final String NAME = "rrf"; | ||
|
||
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); | ||
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); | ||
public static final ParseField FIELDS_FIELD = new ParseField("fields"); | ||
public static final ParseField QUERY_FIELD = new ParseField("query"); | ||
public static final ParseField WEIGHTS_FIELD = new ParseField("weights"); | ||
|
||
public static final int DEFAULT_RANK_CONSTANT = 60; | ||
|
||
private final float[] weights; | ||
|
||
@SuppressWarnings("unchecked") | ||
static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>( | ||
NAME, | ||
false, | ||
args -> { | ||
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0]; | ||
List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0]; | ||
List<String> fields = (List<String>) args[1]; | ||
String query = (String) args[2]; | ||
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3]; | ||
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4]; | ||
|
||
List<RetrieverSource> innerRetrievers = childRetrievers != null | ||
? childRetrievers.stream().map(RetrieverSource::from).toList() | ||
: List.of(); | ||
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant); | ||
int n = retrieverComponents.size(); | ||
List<RetrieverSource> innerRetrievers = new ArrayList<>(n); | ||
float[] weights = new float[n]; | ||
for (int i = 0; i < n; i++) { | ||
RRFRetrieverComponent component = retrieverComponents.get(i); | ||
innerRetrievers.add(RetrieverSource.from(component.retriever())); | ||
weights[i] = component.weight(); | ||
} | ||
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights); | ||
} | ||
); | ||
|
||
static { | ||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> { | ||
p.nextToken(); | ||
String name = p.currentName(); | ||
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c); | ||
c.trackRetrieverUsage(retrieverBuilder.getName()); | ||
p.nextToken(); | ||
return retrieverBuilder; | ||
}, RETRIEVERS_FIELD); | ||
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD); | ||
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD); | ||
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); | ||
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD); | ||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD); | ||
PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD); | ||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD); | ||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); | ||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD); | ||
PARSER.declareFloatArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS_FIELD); | ||
RetrieverBuilder.declareBaseParserFields(PARSER); | ||
} | ||
|
||
|
@@ -103,27 +108,46 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP | |
private final int rankConstant; | ||
|
||
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) { | ||
this(childRetrievers, null, null, rankWindowSize, rankConstant); | ||
this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers)); | ||
} | ||
|
||
// Builds a weight array with one DEFAULT_WEIGHT entry per element of the given list. | ||
// A null list is treated like an empty one and yields a zero-length array. | ||
private static float[] createDefaultWeights(List<?> retrievers) { | ||
// Guard against null so the array size can be computed without a prior check by the caller. | ||
int size = retrievers == null ? 0 : retrievers.size(); | ||
float[] defaultWeights = new float[size]; | ||
Arrays.fill(defaultWeights, DEFAULT_WEIGHT); | ||
return defaultWeights; | ||
} | ||
|
||
public RRFRetrieverBuilder( | ||
List<RetrieverSource> childRetrievers, | ||
List<String> fields, | ||
String query, | ||
int rankWindowSize, | ||
int rankConstant | ||
int rankConstant, | ||
float[] weights | ||
) { | ||
// Use a mutable list for childRetrievers so that we can use addChild | ||
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize); | ||
this.fields = fields == null ? null : List.copyOf(fields); | ||
this.query = query; | ||
this.rankConstant = rankConstant; | ||
Objects.requireNonNull(weights, "weights must not be null"); | ||
if (weights.length != innerRetrievers.size()) { | ||
throw new IllegalArgumentException( | ||
"weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]" | ||
); | ||
} | ||
this.weights = weights; | ||
} | ||
|
||
public int rankConstant() { | ||
return rankConstant; | ||
} | ||
|
||
// Returns the weights array stored on this builder. | ||
// NOTE(review): the field itself is returned, not a copy, so a caller that writes to the array changes this builder's state — confirm that is intended. | ||
public float[] weights() { | ||
return weights; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
|
@@ -137,6 +161,7 @@ public ActionRequestValidationException validate( | |
boolean allowPartialSearchResults | ||
) { | ||
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); | ||
|
||
return MultiFieldsInnerRetrieverUtils.validateParams( | ||
innerRetrievers, | ||
fields, | ||
|
@@ -151,7 +176,14 @@ public ActionRequestValidationException validate( | |
|
||
@Override | ||
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) { | ||
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant); | ||
RRFRetrieverBuilder clone = new RRFRetrieverBuilder( | ||
newRetrievers, | ||
this.fields, | ||
this.query, | ||
this.rankWindowSize, | ||
this.rankConstant, | ||
this.weights | ||
); | ||
clone.preFilterQueryBuilders = newPreFilterQueryBuilders; | ||
clone.retrieverName = retrieverName; | ||
return clone; | ||
|
@@ -183,7 +215,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults | |
|
||
// calculate the current rrf score for this document | ||
// later used to sort and convert to a rank | ||
value.score += 1.0f / (rankConstant + frank); | ||
value.score += this.weights[findex] * (1.0f / (rankConstant + frank)); | ||
|
||
if (explain && value.positions != null && value.scores != null) { | ||
// record the position for each query | ||
|
@@ -233,15 +265,18 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { | |
); | ||
} | ||
|
||
List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers( | ||
List<RetrieverBuilder> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers( | ||
fields, | ||
query, | ||
localIndicesMetadata.values(), | ||
r -> { | ||
List<RetrieverSource> retrievers = r.stream() | ||
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource) | ||
.toList(); | ||
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant); | ||
List<RetrieverSource> retrievers = new ArrayList<>(r.size()); | ||
for (var retriever : r) { | ||
retrievers.add(retriever.retrieverSource()); | ||
} | ||
float[] weights = new float[retrievers.size()]; | ||
Arrays.fill(weights, 1.0f); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't need to hard-code a weight of 1.0 here. The weight validator ensures that every |
||
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights); | ||
}, | ||
w -> { | ||
if (w != 1.0f) { | ||
|
@@ -250,12 +285,20 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { | |
); | ||
} | ||
} | ||
).stream().map(RetrieverSource::from).toList(); | ||
); | ||
|
||
if (fieldsInnerRetrievers.isEmpty() == false) { | ||
// TODO: This is an incomplete solution as it does not address other incomplete copy issues | ||
// (such as dropping the retriever name and min score) | ||
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant); | ||
int size = fieldsInnerRetrievers.size(); | ||
List<RetrieverSource> sources = new ArrayList<>(size); | ||
float[] weights = new float[size]; | ||
Arrays.fill(weights, RRFRetrieverComponent.DEFAULT_WEIGHT); | ||
for (int i = 0; i < size; i++) { | ||
sources.add(RetrieverSource.from(fieldsInnerRetrievers.get(i))); | ||
weights[i] = RRFRetrieverComponent.DEFAULT_WEIGHT; | ||
} | ||
rewritten = new RRFRetrieverBuilder(sources, null, null, rankWindowSize, rankConstant, weights); | ||
Comment on lines
-253
to
+301
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not a functional difference, but we could clean this up a bit and remove some duplicated logic:
|
||
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); | ||
} else { | ||
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices | ||
|
@@ -266,29 +309,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { | |
return rewritten; | ||
} | ||
|
||
// ---- FOR TESTING XCONTENT PARSING ---- | ||
|
||
@Override | ||
public boolean doEquals(Object o) { | ||
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; | ||
return super.doEquals(o) | ||
&& Objects.equals(fields, that.fields) | ||
&& Objects.equals(query, that.query) | ||
&& rankConstant == that.rankConstant; | ||
} | ||
|
||
@Override | ||
public int doHashCode() { | ||
return Objects.hash(super.doHashCode(), fields, query, rankConstant); | ||
} | ||
|
||
@Override | ||
public void doToXContent(XContentBuilder builder, Params params) throws IOException { | ||
if (innerRetrievers.isEmpty() == false) { | ||
builder.startArray(RETRIEVERS_FIELD.getPreferredName()); | ||
|
||
for (var entry : innerRetrievers) { | ||
entry.retriever().toXContent(builder, params); | ||
for (int i = 0; i < innerRetrievers.size(); i++) { | ||
RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]); | ||
component.toXContent(builder, params); | ||
} | ||
builder.endArray(); | ||
} | ||
|
@@ -306,5 +333,28 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept | |
|
||
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); | ||
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant); | ||
if (weights.length > 0) { | ||
builder.startArray(WEIGHTS_FIELD.getPreferredName()); | ||
for (float weight : weights) { | ||
builder.value(weight); | ||
} | ||
builder.endArray(); | ||
} | ||
} | ||
|
||
// ---- FOR TESTING XCONTENT PARSING ---- | ||
// Equality for the state this class adds on top of the superclass: fields, query, rankConstant and weights. | ||
// NOTE(review): the argument is cast without a type check — presumably the caller guarantees the runtime type; verify against the base class. | ||
@Override | ||
public boolean doEquals(Object o) { | ||
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; | ||
return super.doEquals(o) | ||
&& Objects.equals(fields, that.fields) | ||
&& Objects.equals(query, that.query) | ||
&& rankConstant == that.rankConstant | ||
// Arrays.equals compares the weights element by element rather than by array reference. | ||
&& Arrays.equals(weights, that.weights); | ||
} | ||
|
||
// Hash over the superclass hash plus fields, query, rankConstant and the contents of weights. | ||
@Override | ||
public int doHashCode() { | ||
// Arrays.hashCode makes the hash depend on the weight values instead of the array's identity. | ||
return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights)); | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.