-
Notifications
You must be signed in to change notification settings - Fork 25.3k
Implement support for weighted rrf #130658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Implement support for weighted rrf #130658
Conversation
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
7614936
to
e7f0a90
Compare
782e3ca
to
f1eede5
Compare
f1eede5
to
f5e1572
Compare
f1eede5
to
0640099
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice progress, this is looking better. The hybrid object parsing is definitely a beast.
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
Outdated
Show resolved
Hide resolved
static final ConstructingObjectParser<RRFRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>( | ||
"rrf_component", | ||
false, | ||
(args, context) -> { | ||
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0]; | ||
Float weight = (Float) args[1]; | ||
return new RRFRetrieverComponent(retrieverBuilder, weight); | ||
} | ||
); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where do we actually use this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is used in the static method:
public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context)
which typically calls:
return PARSER.apply(parser, context);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get that's where it's normally used, but where is it actually used in this implementation? RRFRetrieverComponen#fromXContent
, in its current form, doesn't use it.
if (RETRIEVER_FIELD.match(firstFieldName, parser.getDeprecationHandler()) | ||
|| WEIGHT_FIELD.match(firstFieldName, parser.getDeprecationHandler())) { | ||
// This is a structured component - parse manually | ||
RetrieverBuilder retriever = null; | ||
Float weight = null; | ||
|
||
do { | ||
String fieldName = parser.currentName(); | ||
if (RETRIEVER_FIELD.match(fieldName, parser.getDeprecationHandler())) { | ||
if (retriever != null) { | ||
throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified"); | ||
} | ||
parser.nextToken(); | ||
parser.nextToken(); | ||
String retrieverType = parser.currentName(); | ||
retriever = parser.namedObject(RetrieverBuilder.class, retrieverType, context); | ||
context.trackRetrieverUsage(retriever.getName()); | ||
parser.nextToken(); | ||
} else if (WEIGHT_FIELD.match(fieldName, parser.getDeprecationHandler())) { | ||
if (weight != null) { | ||
throw new ParsingException(parser.getTokenLocation(), "[weight] field can only be specified once"); | ||
} | ||
parser.nextToken(); | ||
weight = parser.floatValue(); | ||
} else { | ||
if (retriever != null) { | ||
throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified"); | ||
} | ||
throw new ParsingException( | ||
parser.getTokenLocation(), | ||
"unknown field [{}], expected [{}] or [{}]", | ||
fieldName, | ||
RETRIEVER_FIELD.getPreferredName(), | ||
WEIGHT_FIELD.getPreferredName() | ||
); | ||
} | ||
} while (parser.nextToken() == XContentParser.Token.FIELD_NAME); | ||
|
||
if (retriever == null) { | ||
throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever"); | ||
} | ||
|
||
return new RRFRetrieverComponent(retriever, weight); | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand this is complex, but is there an opportunity to use a ConstructingObjectParser
here once we know this is a structured component? @ioanatia WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am bit skeptical on that change, lets see what @ioanatia thinks as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit of cleanup left, but this is coming along 👍
retrievers.add(retriever.retrieverSource()); | ||
} | ||
float[] weights = new float[retrievers.size()]; | ||
Arrays.fill(weights, 1.0f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to hard-code a weight of 1.0 here. The weight validator ensures that every WeightedRetrieverSource
's weight is 1.0, so we can use that to populate the weights
array.
).stream().map(RetrieverSource::from).toList(); | ||
); | ||
|
||
if (fieldsInnerRetrievers.isEmpty() == false) { | ||
// TODO: This is a incomplete solution as it does not address other incomplete copy issues | ||
// (such as dropping the retriever name and min score) | ||
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant); | ||
int size = fieldsInnerRetrievers.size(); | ||
List<RetrieverSource> sources = new ArrayList<>(size); | ||
float[] weights = new float[size]; | ||
Arrays.fill(weights, RRFRetrieverComponent.DEFAULT_WEIGHT); | ||
for (int i = 0; i < size; i++) { | ||
sources.add(RetrieverSource.from(fieldsInnerRetrievers.get(i))); | ||
weights[i] = RRFRetrieverComponent.DEFAULT_WEIGHT; | ||
} | ||
rewritten = new RRFRetrieverBuilder(sources, null, null, rankWindowSize, rankConstant, weights); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a functional difference, but we could clean this up a bit and remove some duplicated logic:
- Keep
.stream().map(RetrieverSource::from).toList()
to build a list ofRetrieverSource
s. We need that anyways. - Use
createDefaultWeights
to create theweights
array.
static final ConstructingObjectParser<RRFRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>( | ||
"rrf_component", | ||
false, | ||
(args, context) -> { | ||
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0]; | ||
Float weight = (Float) args[1]; | ||
return new RRFRetrieverComponent(retrieverBuilder, weight); | ||
} | ||
); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get that's where it's normally used, but where is it actually used in this implementation? RRFRetrieverComponen#fromXContent
, in its current form, doesn't use it.
for (int i = 0; i < innerRetrievers.size(); i++) { | ||
weights[i] = randomFloat(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: we could populate the weights
array in the while (retrieverCount > 0)
loop
public void testRRFRetrieverParsingWithDefaultWeights() throws IOException { | ||
String restContent = """ | ||
{ | ||
"retriever": { | ||
"rrf": { | ||
"retrievers": [ | ||
{ | ||
"test": { | ||
"value": "first" | ||
} | ||
}, | ||
{ | ||
"test": { | ||
"value": "second" | ||
} | ||
} | ||
], | ||
"rank_window_size": 100, | ||
"rank_constant": 10, | ||
"min_score": 20.0, | ||
"_name": "foo_rrf" | ||
} | ||
} | ||
} | ||
"""; | ||
checkRRFRetrieverParsing(restContent); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this test functionally different than testRRFRetrieverParsing
?
} | ||
"""; | ||
|
||
expectParsingException(negativeWeightContent, "weight] must be non-negative"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Missing a [
here
try (XContentParser parser = createParser(JsonXContent.jsonXContent, legacyJson)) { | ||
SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true); | ||
assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class)); | ||
RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever(); | ||
assertArrayEquals(new float[] { 1.0f, 1.0f }, rrf.weights(), 0.001f); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: we could factor this duplicated code out. It doesn't necessarily need to be refactored into a named method, we could make a local BiConsumer
(or CheckedBiConsumer
) that takes the JSON string and the expected weight array.
- match: { hits.hits.0._id: "1" } | ||
|
||
--- | ||
"Weighted RRF retriever defaults to weight 1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very big nit here, but this test method is misleading. It doesn't actually check that that the default weight is one (testRRFRetrieverParsingSyntax
does that). It only checks that weight
is optional. Maybe change the test name to be more accurate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a test where we show that we can use weight
to boost a document in the result set? In other words, show that by changing weight
we can change the result order in an expected way.
Implement support for weighted RRF
Summary
This PR implements support for weighted RRF (Reciprocal Rank Fusion) retrievers, allowing users to specify custom weights for each sub-retriever within an RRF retriever configuration. This addresses a common customer request to customize the influence of different retrievers in the RRF scoring process.
Core Implementation
Key Features
1. Weighted Retriever Support
Users can now specify weights for individual retrievers: