11
11
12
12
import org .apache .lucene .index .FloatVectorValues ;
13
13
import org .apache .lucene .util .VectorUtil ;
14
+ import org .apache .lucene .util .hnsw .IntToIntFunction ;
15
+ import org .elasticsearch .index .codec .vectors .SampleReader ;
14
16
import org .elasticsearch .simdvec .ESVectorUtil ;
15
17
16
18
import java .io .IOException ;
@@ -74,12 +76,12 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
74
76
return centroids ;
75
77
}
76
78
77
- private boolean stepLloyd (
79
+ private static boolean stepLloyd (
78
80
FloatVectorValues vectors ,
81
+ IntToIntFunction translateOrd ,
79
82
float [][] centroids ,
80
83
float [][] nextCentroids ,
81
84
int [] assignments ,
82
- int sampleSize ,
83
85
List <int []> neighborhoods
84
86
) throws IOException {
85
87
boolean changed = false ;
@@ -90,17 +92,18 @@ private boolean stepLloyd(
90
92
Arrays .fill (nextCentroid , 0.0f );
91
93
}
92
94
93
- for (int i = 0 ; i < sampleSize ; i ++) {
94
- float [] vector = vectors .vectorValue (i );
95
- final int assignment = assignments [i ];
95
+ for (int idx = 0 ; idx < vectors .size (); idx ++) {
96
+ float [] vector = vectors .vectorValue (idx );
97
+ int vectorOrd = translateOrd .apply (idx );
98
+ final int assignment = assignments [vectorOrd ];
96
99
final int bestCentroidOffset ;
97
100
if (neighborhoods != null ) {
98
101
bestCentroidOffset = getBestCentroidFromNeighbours (centroids , vector , assignment , neighborhoods .get (assignment ));
99
102
} else {
100
103
bestCentroidOffset = getBestCentroid (centroids , vector );
101
104
}
102
105
if (assignment != bestCentroidOffset ) {
103
- assignments [i ] = bestCentroidOffset ;
106
+ assignments [vectorOrd ] = bestCentroidOffset ;
104
107
changed = true ;
105
108
}
106
109
centroidCounts [bestCentroidOffset ]++;
@@ -121,7 +124,7 @@ private boolean stepLloyd(
121
124
return changed ;
122
125
}
123
126
124
- int getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int centroidIdx , int [] centroidOffsets ) {
127
+ private static int getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int centroidIdx , int [] centroidOffsets ) {
125
128
int bestCentroidOffset = centroidIdx ;
126
129
assert centroidIdx >= 0 && centroidIdx < centroids .length ;
127
130
float minDsq = VectorUtil .squareDistance (vector , centroids [centroidIdx ]);
@@ -135,7 +138,7 @@ int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centr
135
138
return bestCentroidOffset ;
136
139
}
137
140
138
- int getBestCentroid (float [][] centroids , float [] vector ) {
141
+ private static int getBestCentroid (float [][] centroids , float [] vector ) {
139
142
int bestCentroidOffset = 0 ;
140
143
float minDsq = Float .MAX_VALUE ;
141
144
for (int i = 0 ; i < centroids .length ; i ++) {
@@ -281,24 +284,34 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
281
284
}
282
285
}
283
286
284
- void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , List <int []> neighborhoods ) throws IOException {
287
+ private void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , List <int []> neighborhoods ) throws IOException {
285
288
float [][] centroids = kMeansIntermediate .centroids ();
286
289
int k = centroids .length ;
287
290
int n = vectors .size ();
288
291
289
292
if (k == 1 || k >= n ) {
290
293
return ;
291
294
}
292
-
295
+ IntToIntFunction translateOrd = i -> i ;
296
+ FloatVectorValues sampledVectors = vectors ;
297
+ if (sampleSize < n ) {
298
+ sampledVectors = SampleReader .createSampleReader (vectors , sampleSize , 42L );
299
+ translateOrd = sampledVectors ::ordToDoc ;
300
+ }
293
301
int [] assignments = kMeansIntermediate .assignments ();
294
302
assert assignments .length == n ;
295
303
float [][] nextCentroids = new float [centroids .length ][vectors .dimension ()];
296
304
for (int i = 0 ; i < maxIterations ; i ++) {
297
- if (stepLloyd (vectors , centroids , nextCentroids , assignments , sampleSize , neighborhoods ) == false ) {
305
+ // This is potentially sampled, so we need to translate ordinals
306
+ if (stepLloyd (sampledVectors , translateOrd , centroids , nextCentroids , assignments , neighborhoods ) == false ) {
298
307
break ;
299
308
}
300
309
}
301
- stepLloyd (vectors , centroids , nextCentroids , assignments , vectors .size (), neighborhoods );
310
+ // If we were sampled, do a once over the full set of vectors to finalize the centroids
311
+ if (sampleSize < n ) {
312
+ // No ordinal translation needed here, we are using the full set of vectors
313
+ stepLloyd (vectors , i -> i , centroids , nextCentroids , assignments , neighborhoods );
314
+ }
302
315
}
303
316
304
317
/**
0 commit comments