Skip to content

Commit a5d9e34

Browse files
committed
wip
1 parent 7614936 commit a5d9e34

File tree

3 files changed

+65
-27
lines changed

3 files changed

+65
-27
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.xcontent.XContentParser;
3030
import org.elasticsearch.xpack.core.XPackPlugin;
3131
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
32+
import org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent;
3233

3334
import java.io.IOException;
3435
import java.util.ArrayList;
@@ -39,6 +40,7 @@
3940

4041
import static org.elasticsearch.action.ValidateActions.addValidationError;
4142
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;
43+
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.RETRIEVER_FIELD;
4244

4345
/**
4446
* An rrf retriever is used to represent an rrf rank element, but
@@ -67,39 +69,28 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
6769
NAME,
6870
true, // Set to true to ignore unknown fields
6971
args -> {
70-
List<RetrieverBuilder> retrievers = args[0] == null ? List.of() : (List<RetrieverBuilder>) args[0];
72+
List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0];
7173
List<String> fields = (List<String>) args[1];
7274
String query = (String) args[2];
7375
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
7476
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
75-
List<Float> weightsList = (List<Float>) args[5];
7677

77-
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = retrievers.stream()
78-
.map(CompoundRetrieverBuilder.RetrieverSource::from)
79-
.toList();
78+
float[] weights = new float[retrieverComponents.size()];
8079

81-
float[] weights;
82-
if (weightsList == null) {
83-
weights = new float[retrievers.size()];
84-
Arrays.fill(weights, DEFAULT_WEIGHT);
85-
} else {
86-
weights = new float[weightsList.size()];
87-
for (int i = 0; i < weightsList.size(); i++) {
88-
weights[i] = weightsList.get(i);
89-
}
80+
int index = 0;
81+
List<RetrieverSource> innerRetrievers = new ArrayList<>();
82+
for (RRFRetrieverComponent component : retrieverComponents) {
83+
innerRetrievers.add(RetrieverSource.from(component.retriever));
84+
weights[index] = component.weight;
85+
index++;
9086
}
9187

9288
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
9389
}
9490
);
9591

9692
static {
97-
PARSER.declareNamedObjects(
98-
ConstructingObjectParser.optionalConstructorArg(),
99-
(p, c, n) -> p.namedObject(RetrieverBuilder.class, n, c),
100-
(v) -> { /* This callback enables array syntax */ },
101-
RETRIEVERS_FIELD
102-
);
93+
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
10394
PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
10495
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
10596
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RankBuilder.RANK_WINDOW_SIZE_FIELD);

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverComponent.java

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,24 @@
77

88
package org.elasticsearch.xpack.rank.rrf;
99

10+
import org.elasticsearch.common.ParsingException;
1011
import org.elasticsearch.search.retriever.RetrieverBuilder;
1112
import org.elasticsearch.search.retriever.RetrieverParserContext;
1213
import org.elasticsearch.xcontent.ConstructingObjectParser;
14+
import org.elasticsearch.xcontent.CopyingXContentParser;
15+
import org.elasticsearch.xcontent.ObjectParser;
1316
import org.elasticsearch.xcontent.ParseField;
1417
import org.elasticsearch.xcontent.ToXContentObject;
1518
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xcontent.XContentFactory;
20+
import org.elasticsearch.xcontent.XContentParseException;
1621
import org.elasticsearch.xcontent.XContentParser;
22+
import org.elasticsearch.xcontent.XContentSubParser;
23+
import org.elasticsearch.xcontent.XContentType;
1724

1825
import java.io.IOException;
26+
import java.util.HashMap;
27+
import java.util.Map;
1928

2029
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
2130
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -24,7 +33,6 @@ public class RRFRetrieverComponent implements ToXContentObject {
2433

2534
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
2635
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
27-
2836
static final float DEFAULT_WEIGHT = 1f;
2937

3038
final RetrieverBuilder retriever;
@@ -75,6 +83,37 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
7583
}
7684

7785
public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
78-
return PARSER.apply(parser, context);
86+
RetrieverBuilder innerRetriever = null;
87+
float weight = DEFAULT_WEIGHT;
88+
89+
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
90+
throw new ParsingException(parser.getTokenLocation(), "[{}] expected object", parser.currentToken());
91+
}
92+
93+
while ((parser.nextToken()) != XContentParser.Token.END_OBJECT) {
94+
var name = parser.currentName();
95+
96+
if (name.equals(RETRIEVER_FIELD.getPreferredName())) {
97+
if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
98+
throw new ParsingException(parser.getTokenLocation(), "[{}] expected object", parser.currentToken());
99+
}
100+
parser.nextToken();
101+
102+
name = parser.currentName();
103+
innerRetriever = parser.namedObject(RetrieverBuilder.class, name, context);
104+
parser.nextToken();
105+
} else if (name.equals(WEIGHT_FIELD.getPreferredName())) {
106+
if (parser.nextToken() != XContentParser.Token.VALUE_NUMBER) {
107+
throw new ParsingException(parser.getTokenLocation(), "[{}] expected number", parser.currentToken());
108+
}
109+
110+
weight = parser.floatValue();
111+
} else {
112+
innerRetriever = parser.namedObject(RetrieverBuilder.class, name, context);
113+
parser.nextToken();
114+
break;
115+
}
116+
}
117+
return new RRFRetrieverComponent(innerRetriever, weight);
79118
}
80119
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,28 @@ public void testRRFRetrieverParsing() throws IOException {
115115
"retrievers": [
116116
{
117117
"test": {
118-
"value": "foo"
118+
"value": "foobar"
119119
}
120120
},
121121
{
122-
"test": {
123-
"value": "bar"
122+
"retriever": {
123+
"test": {
124+
"value": "foo"
125+
}
124126
}
127+
},
128+
{
129+
"retriever": {
130+
"test": {
131+
"value": "bar"
132+
}
133+
},
134+
"weight": 1.3
125135
}
126136
],
127-
"fields": ["field1", "field2"],
128137
"query": "baz",
129138
"rank_window_size": 100,
130139
"rank_constant": 10,
131-
"weights": [0.5, 0.5],
132140
"min_score": 20.0,
133141
"_name": "foo_rrf"
134142
}

0 commit comments

Comments
 (0)