@@ -58,6 +58,8 @@ class JoinOp {
58
58
59
59
private Set uniqueKeys = new LinkedHashSet ()
60
60
61
+ private Map<Object ,Object > originalKeyMap = new HashMap ()
62
+
61
63
JoinOp ( DataflowReadChannel source , DataflowReadChannel target , Map params = null ) {
62
64
CheckHelper . checkParams(' join' , params, JOIN_PARAMS )
63
65
this . source = source
@@ -173,6 +175,24 @@ class JoinOp {
173
175
// get the index key for this object
174
176
final item0 = DataflowHelper . makeKey(pivot, data)
175
177
178
+ // Store the mapping from normalized key to original key
179
+ // Prefer GroupKey over plain keys
180
+ def existingOriginal = originalKeyMap. get(item0. keys)
181
+ if (existingOriginal == null ) {
182
+ originalKeyMap[item0. keys] = item0. originalKeys
183
+ } else {
184
+ // Check if any of the new original keys is a GroupKey
185
+ // If so, replace the existing mapping
186
+ for (int i = 0 ; i < item0. originalKeys. size(); i++ ) {
187
+ def newKey = item0. originalKeys[i]
188
+ def oldKey = existingOriginal instanceof List ? existingOriginal[i] : existingOriginal
189
+ if (newKey instanceof GroupKey && ! (oldKey instanceof GroupKey )) {
190
+ originalKeyMap[item0. keys] = item0. originalKeys
191
+ break
192
+ }
193
+ }
194
+ }
195
+
176
196
// check for unique keys
177
197
checkForDuplicate(item0. keys, item0. values, index, false )
178
198
@@ -190,8 +210,8 @@ class JoinOp {
190
210
def entries = channels[index]
191
211
192
212
// add the received item to the list
193
- // when it is used in the gather op add always as the first item
194
- entries << item0. values
213
+ // Store the full KeyPair to preserve original keys
214
+ entries << item0
195
215
setSingleton(index, item0. values. size()== 0 )
196
216
197
217
// now check if it has received an element matching for each channel
@@ -200,15 +220,39 @@ class JoinOp {
200
220
}
201
221
202
222
def result = []
223
+
224
+ // Find the best key (prefer GroupKey) from all channels
225
+ def bestOriginalKeys = null
226
+ for (Map.Entry < Integer ,List > entry : channels. entrySet()) {
227
+ def channelItems = entry. getValue()
228
+ if (channelItems && channelItems. size() > 0 ) {
229
+ def keyPair = channelItems[0 ] as KeyPair
230
+ if (bestOriginalKeys == null ) {
231
+ bestOriginalKeys = keyPair. originalKeys
232
+ } else {
233
+ // Check if this channel has a GroupKey version
234
+ for (int i = 0 ; i < keyPair. originalKeys. size(); i++ ) {
235
+ def candidateKey = keyPair. originalKeys[i]
236
+ def currentKey = bestOriginalKeys instanceof List ? bestOriginalKeys[i] : bestOriginalKeys
237
+ if (candidateKey instanceof GroupKey && ! (currentKey instanceof GroupKey )) {
238
+ bestOriginalKeys = keyPair. originalKeys
239
+ break
240
+ }
241
+ }
242
+ }
243
+ }
244
+ }
245
+
203
246
// add the key
204
- addToList(result, item0. keys )
247
+ addToList(result, bestOriginalKeys ?: item0. originalKeys )
205
248
206
249
final itr = channels. iterator()
207
250
while ( itr. hasNext() ) {
208
251
def entry = (Map.Entry < Integer ,List > )itr. next()
209
252
210
253
def list = entry. getValue()
211
- addToList(result, list[0 ])
254
+ def keyPair = list[0 ] as KeyPair
255
+ addToList(result, keyPair. values)
212
256
213
257
list. remove(0 )
214
258
if ( list. size() == 0 ) {
@@ -221,6 +265,17 @@ class JoinOp {
221
265
return result
222
266
}
223
267
268
+ // Placeholder helper: looks up the items buffered for `key` on channel `channelIndex`, but returns null on every path
269
+ private def getOriginalDataFromBuffer (Map<Object ,Map<Integer ,List > > buffer , Object key , int channelIndex ) {
270
+ def channels = buffer. get(key)
271
+ if (channels == null ) return null
272
+ def items = channels. get(channelIndex)
273
+ if (items == null || items. isEmpty()) return null
274
+ // TODO: rebuilding the original data from the buffered values and the key is not implemented here;
275
+ // doing it properly would require tracking the full original items
276
+ return null // the `items` found above are discarded; NOTE(review): no caller is visible in this diff - confirm this helper is still needed
277
+ }
278
+
224
279
private final void checkRemainder (Map<Object ,Map<Integer ,List > > buffers , int count , DataflowWriteChannel target ) {
225
280
log. trace " Operator `join` remainder buffer: ${ -> buffers} "
226
281
@@ -231,17 +286,43 @@ class JoinOp {
231
286
232
287
boolean fill= false
233
288
def result = new ArrayList (count+1 )
234
- addToList(result, key)
289
+
290
+ // Find the best original key from available channels
291
+ def bestOriginalKey = null
292
+ for ( int i= 0 ; i< count; i++ ) {
293
+ List items = entry[i]
294
+ if ( items && items. size() > 0 ) {
295
+ def keyPair = items[0 ] as KeyPair
296
+ if (bestOriginalKey == null ) {
297
+ bestOriginalKey = keyPair. originalKeys
298
+ } else {
299
+ // Check if this channel has a GroupKey version
300
+ for (int j = 0 ; j < keyPair. originalKeys. size(); j++ ) {
301
+ def candidateKey = keyPair. originalKeys[j]
302
+ def currentKey = bestOriginalKey instanceof List ? bestOriginalKey[j] : bestOriginalKey
303
+ if (candidateKey instanceof GroupKey && ! (currentKey instanceof GroupKey )) {
304
+ bestOriginalKey = keyPair. originalKeys
305
+ break
306
+ }
307
+ }
308
+ }
309
+ }
310
+ }
311
+
312
+ // Use the best available original key, or fall back to the map key
313
+ def originalKey = bestOriginalKey ?: originalKeyMap. get(key) ?: key
314
+ addToList(result, originalKey)
235
315
236
316
for ( int i= 0 ; i< count; i++ ) {
237
- List values = entry[i]
238
- if ( values ) {
317
+ List items = entry[i]
318
+ if ( items ) {
319
+ def keyPair = items[0 ] as KeyPair
239
320
// track unique keys
240
- checkForDuplicate(key, values[ 0 ] , i,false )
321
+ checkForDuplicate(key, keyPair . values, i, false )
241
322
242
- addToList(result, values[ 0 ] )
323
+ addToList(result, keyPair . values)
243
324
fill |= true
244
- values . remove(0 )
325
+ items . remove(0 )
245
326
}
246
327
else if ( ! singleton(i) ) {
247
328
addToList(result, null )
@@ -277,7 +358,17 @@ class JoinOp {
277
358
278
359
result[key] = []
279
360
for ( Map.Entry entry : remainder0 ) {
280
- result[key]. add(csv0(entry. value,' ,' ))
361
+ def items = entry. value
362
+ def values = []
363
+ // Extract values from KeyPair objects
364
+ items. each { item ->
365
+ if (item instanceof KeyPair ) {
366
+ values. add(item. values)
367
+ } else {
368
+ values. add(item)
369
+ }
370
+ }
371
+ result[key]. add(csv0(values,' ,' ))
281
372
}
282
373
}
283
374
0 commit comments