Skip to content

Implement support for weighted rrf #130658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8be20f0
RRFRetrieverComponent added:
mridula-s109 Jul 4, 2025
a8f6487
Modified parser, toXcontent and included component in the RetrieverBu…
mridula-s109 Jul 4, 2025
e07c38d
[CI] Auto commit changes from spotless
Jul 4, 2025
33d3da4
Resolved merge conflicts
mridula-s109 Jul 15, 2025
5fb5568
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 15, 2025
3ba149c
Fixed compile issues in tests
mridula-s109 Jul 15, 2025
d5749f6
[CI] Auto commit changes from spotless
Jul 15, 2025
7614936
trying to resolve parse errros
mridula-s109 Jul 16, 2025
a5d9e34
wip
ioanatia Jul 17, 2025
0640099
Modified builder
mridula-s109 Jul 17, 2025
cec23c2
[CI] Auto commit changes from spotless
Jul 17, 2025
6da9e15
Removed unnecessary code
mridula-s109 Jul 18, 2025
51b350e
Fixed import
mridula-s109 Jul 18, 2025
4050a3a
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 18, 2025
ea664eb
Enhanced tests
mridula-s109 Jul 18, 2025
98e72be
Fixed the failing tests
mridula-s109 Jul 21, 2025
7de8c7a
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 21, 2025
c778dd2
Yaml tests were added
mridula-s109 Jul 22, 2025
c7b331d
Added cluster features to it
mridula-s109 Jul 22, 2025
f543cbe
Fixed spotless
mridula-s109 Jul 22, 2025
75ab8d0
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 22, 2025
f5086c1
Update docs/changelog/130658.yaml
mridula-s109 Jul 22, 2025
fafb50f
Fixed the relaxed constraints
mridula-s109 Jul 23, 2025
e535864
Resolving issues
mridula-s109 Jul 23, 2025
78f8641
Resolved PR comments
mridula-s109 Jul 23, 2025
02647b1
removed simplified rrf
mridula-s109 Jul 23, 2025
2010f3a
changed the test file back to its original state
mridula-s109 Jul 24, 2025
7433023
Resolved comments to have ahelper method and the test case to use it
mridula-s109 Jul 24, 2025
a2bf4de
made parsing robust
mridula-s109 Jul 24, 2025
eebf577
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 24, 2025
0388abd
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 24, 2025
74ed8db
IT test reverted
mridula-s109 Jul 24, 2025
6d7e8ff
Replaced the declareString array parser
mridula-s109 Jul 25, 2025
f1e14ce
Enforced weights as nonnull
mridula-s109 Jul 25, 2025
fd30387
Fixed the weights null
mridula-s109 Jul 25, 2025
3a82a28
Empty weight shouldnt be serialised
mridula-s109 Jul 25, 2025
77c14d3
[CI] Auto commit changes from spotless
Jul 25, 2025
45ca068
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130658.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130658
summary: Implement support for weighted rrf
area: Relevance
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public Set<NodeFeature> getTestFeatures() {
LINEAR_RETRIEVER_L2_NORM,
LINEAR_RETRIEVER_MINSCORE_FIX,
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
RRFRetrieverBuilder.WEIGHTED_SUPPORT
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
Expand All @@ -37,7 +38,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;

/**
* An rrf retriever is used to represent an rrf rank element, but
Expand All @@ -48,46 +49,50 @@
*/
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");

public static final String NAME = "rrf";

public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
public static final ParseField FIELDS_FIELD = new ParseField("fields");
public static final ParseField QUERY_FIELD = new ParseField("query");
public static final ParseField WEIGHTS_FIELD = new ParseField("weights");

public static final int DEFAULT_RANK_CONSTANT = 60;

private final float[] weights;

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
NAME,
false,
args -> {
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0];
List<String> fields = (List<String>) args[1];
String query = (String) args[2];
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];

List<RetrieverSource> innerRetrievers = childRetrievers != null
? childRetrievers.stream().map(RetrieverSource::from).toList()
: List.of();
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
int n = retrieverComponents.size();
List<RetrieverSource> innerRetrievers = new ArrayList<>(n);
float[] weights = new float[n];
for (int i = 0; i < n; i++) {
RRFRetrieverComponent component = retrieverComponents.get(i);
innerRetrievers.add(RetrieverSource.from(component.retriever()));
weights[i] = component.weight();
}
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
}
);

static {
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
p.nextToken();
String name = p.currentName();
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
c.trackRetrieverUsage(retrieverBuilder.getName());
p.nextToken();
return retrieverBuilder;
}, RETRIEVERS_FIELD);
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
PARSER.declareFloatArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS_FIELD);
RetrieverBuilder.declareBaseParserFields(PARSER);
}

Expand All @@ -103,27 +108,46 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
private final int rankConstant;

public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
this(childRetrievers, null, null, rankWindowSize, rankConstant);
this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers));
}

private static float[] createDefaultWeights(List<?> retrievers) {
int size = retrievers == null ? 0 : retrievers.size();
float[] defaultWeights = new float[size];
Arrays.fill(defaultWeights, DEFAULT_WEIGHT);
return defaultWeights;
}

public RRFRetrieverBuilder(
List<RetrieverSource> childRetrievers,
List<String> fields,
String query,
int rankWindowSize,
int rankConstant
int rankConstant,
float[] weights
) {
// Use a mutable list for childRetrievers so that we can use addChild
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
this.fields = fields == null ? null : List.copyOf(fields);
this.query = query;
this.rankConstant = rankConstant;
Objects.requireNonNull(weights, "weights must not be null");
if (weights.length != innerRetrievers.size()) {
throw new IllegalArgumentException(
"weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]"
);
}
this.weights = weights;
}

public int rankConstant() {
return rankConstant;
}

public float[] weights() {
return weights;
}

@Override
public String getName() {
return NAME;
Expand All @@ -137,6 +161,7 @@ public ActionRequestValidationException validate(
boolean allowPartialSearchResults
) {
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);

return MultiFieldsInnerRetrieverUtils.validateParams(
innerRetrievers,
fields,
Expand All @@ -151,7 +176,14 @@ public ActionRequestValidationException validate(

@Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(
newRetrievers,
this.fields,
this.query,
this.rankWindowSize,
this.rankConstant,
this.weights
);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
Expand Down Expand Up @@ -183,7 +215,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults

// calculate the current rrf score for this document
// later used to sort and covert to a rank
value.score += 1.0f / (rankConstant + frank);
value.score += this.weights[findex] * (1.0f / (rankConstant + frank));

if (explain && value.positions != null && value.scores != null) {
// record the position for each query
Expand Down Expand Up @@ -233,15 +265,18 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
);
}

List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
List<RetrieverBuilder> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
fields,
query,
localIndicesMetadata.values(),
r -> {
List<RetrieverSource> retrievers = r.stream()
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
.toList();
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
List<RetrieverSource> retrievers = new ArrayList<>(r.size());
for (var retriever : r) {
retrievers.add(retriever.retrieverSource());
}
float[] weights = new float[retrievers.size()];
Arrays.fill(weights, 1.0f);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to hard-code a weight of 1.0 here. The weight validator ensures that every WeightedRetrieverSource's weight is 1.0, so we can use that to populate the weights array.

return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
},
w -> {
if (w != 1.0f) {
Expand All @@ -250,12 +285,20 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
);
}
}
).stream().map(RetrieverSource::from).toList();
);

if (fieldsInnerRetrievers.isEmpty() == false) {
// TODO: This is a incomplete solution as it does not address other incomplete copy issues
// (such as dropping the retriever name and min score)
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
int size = fieldsInnerRetrievers.size();
List<RetrieverSource> sources = new ArrayList<>(size);
float[] weights = new float[size];
Arrays.fill(weights, RRFRetrieverComponent.DEFAULT_WEIGHT);
for (int i = 0; i < size; i++) {
sources.add(RetrieverSource.from(fieldsInnerRetrievers.get(i)));
weights[i] = RRFRetrieverComponent.DEFAULT_WEIGHT;
}
rewritten = new RRFRetrieverBuilder(sources, null, null, rankWindowSize, rankConstant, weights);
Comment on lines -253 to +301
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a functional difference, but we could clean this up a bit and remove some duplicated logic:

  • Keep .stream().map(RetrieverSource::from).toList() to build a list of RetrieverSources. We need that anyways.
  • Use createDefaultWeights to create the weights array.

rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
} else {
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
Expand All @@ -266,29 +309,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
return rewritten;
}

// ---- FOR TESTING XCONTENT PARSING ----

@Override
public boolean doEquals(Object o) {
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
return super.doEquals(o)
&& Objects.equals(fields, that.fields)
&& Objects.equals(query, that.query)
&& rankConstant == that.rankConstant;
}

@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
}

@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
if (innerRetrievers.isEmpty() == false) {
builder.startArray(RETRIEVERS_FIELD.getPreferredName());

for (var entry : innerRetrievers) {
entry.retriever().toXContent(builder, params);
for (int i = 0; i < innerRetrievers.size(); i++) {
RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]);
component.toXContent(builder, params);
}
builder.endArray();
}
Expand All @@ -306,5 +333,28 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept

builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
if (weights.length > 0) {
builder.startArray(WEIGHTS_FIELD.getPreferredName());
for (float weight : weights) {
builder.value(weight);
}
builder.endArray();
}
}

// ---- FOR TESTING XCONTENT PARSING ----
@Override
public boolean doEquals(Object o) {
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
return super.doEquals(o)
&& Objects.equals(fields, that.fields)
&& Objects.equals(query, that.query)
&& rankConstant == that.rankConstant
&& Arrays.equals(weights, that.weights);
}

@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
}
}
Loading
Loading