Skip to content

Commit 3c8c526

Browse files
benwtrent authored and mridula-s109 committed
Fix sampling for kmeans and address assignment edge case (elastic#130405)
This commit contains three fixes:
- We should do actual sampling when doing kmeans clustering; taking the first N vectors creates some weird edge cases.
- Having assignments initialized to `0` means that if a vector gets assigned to cluster ord `0`, that cluster's centroid isn't actually updated later in the Lloyd steps. So, this initializes assignments to `-1`.
- If we don't actually sample the vectors for Lloyd, don't bother with the final pass to potentially update the centroids.
1 parent 3c1f4f6 commit 3c8c526

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
import org.apache.lucene.util.Bits;
2626

2727
import java.io.IOException;
28+
import java.util.Arrays;
2829
import java.util.Random;
2930
import java.util.function.IntUnaryOperator;
3031

31-
class SampleReader extends FloatVectorValues implements HasIndexSlice {
32+
public class SampleReader extends FloatVectorValues implements HasIndexSlice {
3233
private final FloatVectorValues origin;
3334
private final int sampleSize;
3435
private final IntUnaryOperator sampleFunction;
@@ -71,21 +72,24 @@ public int getVectorByteLength() {
7172

7273
@Override
7374
public int ordToDoc(int ord) {
74-
throw new IllegalStateException("Not supported");
75+
// get the original ordinal from the sample ordinal
76+
return sampleFunction.applyAsInt(ord);
7577
}
7678

7779
@Override
7880
public Bits getAcceptOrds(Bits acceptDocs) {
7981
throw new IllegalStateException("Not supported");
8082
}
8183

82-
static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
84+
public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
8385
// TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors?
8486
if (k >= origin.size()) {
8587
new SampleReader(origin, origin.size(), i -> i);
8688
}
8789
// TODO maybe use bigArrays?
8890
int[] samples = reservoirSample(origin.size(), k, seed);
91+
// sort to prevent random backwards access weirdness
92+
Arrays.sort(samples);
8993
return new SampleReader(origin, samples.length, i -> samples[i]);
9094
}
9195

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.index.FloatVectorValues;
1313

1414
import java.io.IOException;
15+
import java.util.Arrays;
1516

1617
/**
1718
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
@@ -84,6 +85,8 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
8485

8586
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
8687
int[] assignments = new int[vectors.size()];
88+
// ensure we don't over assign to cluster 0 without adjusting it
89+
Arrays.fill(assignments, -1);
8790
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
8891
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
8992
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import org.apache.lucene.index.FloatVectorValues;
1313
import org.apache.lucene.util.VectorUtil;
14+
import org.apache.lucene.util.hnsw.IntToIntFunction;
15+
import org.elasticsearch.index.codec.vectors.SampleReader;
1416
import org.elasticsearch.simdvec.ESVectorUtil;
1517

1618
import java.io.IOException;
@@ -74,12 +76,12 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
7476
return centroids;
7577
}
7678

77-
private boolean stepLloyd(
79+
private static boolean stepLloyd(
7880
FloatVectorValues vectors,
81+
IntToIntFunction translateOrd,
7982
float[][] centroids,
8083
float[][] nextCentroids,
8184
int[] assignments,
82-
int sampleSize,
8385
List<int[]> neighborhoods
8486
) throws IOException {
8587
boolean changed = false;
@@ -90,17 +92,18 @@ private boolean stepLloyd(
9092
Arrays.fill(nextCentroid, 0.0f);
9193
}
9294

93-
for (int i = 0; i < sampleSize; i++) {
94-
float[] vector = vectors.vectorValue(i);
95-
final int assignment = assignments[i];
95+
for (int idx = 0; idx < vectors.size(); idx++) {
96+
float[] vector = vectors.vectorValue(idx);
97+
int vectorOrd = translateOrd.apply(idx);
98+
final int assignment = assignments[vectorOrd];
9699
final int bestCentroidOffset;
97100
if (neighborhoods != null) {
98101
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
99102
} else {
100103
bestCentroidOffset = getBestCentroid(centroids, vector);
101104
}
102105
if (assignment != bestCentroidOffset) {
103-
assignments[i] = bestCentroidOffset;
106+
assignments[vectorOrd] = bestCentroidOffset;
104107
changed = true;
105108
}
106109
centroidCounts[bestCentroidOffset]++;
@@ -121,7 +124,7 @@ private boolean stepLloyd(
121124
return changed;
122125
}
123126

124-
int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
127+
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
125128
int bestCentroidOffset = centroidIdx;
126129
assert centroidIdx >= 0 && centroidIdx < centroids.length;
127130
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
@@ -135,7 +138,7 @@ int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centr
135138
return bestCentroidOffset;
136139
}
137140

138-
int getBestCentroid(float[][] centroids, float[] vector) {
141+
private static int getBestCentroid(float[][] centroids, float[] vector) {
139142
int bestCentroidOffset = 0;
140143
float minDsq = Float.MAX_VALUE;
141144
for (int i = 0; i < centroids.length; i++) {
@@ -281,24 +284,34 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
281284
}
282285
}
283286

284-
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
287+
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
285288
float[][] centroids = kMeansIntermediate.centroids();
286289
int k = centroids.length;
287290
int n = vectors.size();
288291

289292
if (k == 1 || k >= n) {
290293
return;
291294
}
292-
295+
IntToIntFunction translateOrd = i -> i;
296+
FloatVectorValues sampledVectors = vectors;
297+
if (sampleSize < n) {
298+
sampledVectors = SampleReader.createSampleReader(vectors, sampleSize, 42L);
299+
translateOrd = sampledVectors::ordToDoc;
300+
}
293301
int[] assignments = kMeansIntermediate.assignments();
294302
assert assignments.length == n;
295303
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
296304
for (int i = 0; i < maxIterations; i++) {
297-
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
305+
// This is potentially sampled, so we need to translate ordinals
306+
if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) {
298307
break;
299308
}
300309
}
301-
stepLloyd(vectors, centroids, nextCentroids, assignments, vectors.size(), neighborhoods);
310+
// If we were sampled, do a once over the full set of vectors to finalize the centroids
311+
if (sampleSize < n) {
312+
// No ordinal translation needed here, we are using the full set of vectors
313+
stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
314+
}
302315
}
303316

304317
/**

0 commit comments

Comments
 (0)