Skip to content

Commit d105842

Browse files
committed
Ensure GroupKey is preserved and commutative in join/combine
- Store original keys in KeyPair and use them for output emission in join/combine
- Always prefer GroupKey over plain keys when both are present for the same match
- Fix timing-dependent bugs when joining or combining channels with mixed GroupKey/plain keys
- Update buffer logic to store KeyPair objects, ensuring metadata such as group size is retained
- Update remainder and mismatch handling to work with the new buffer structure
- Add comprehensive tests (all passing) for GroupKey preservation and commutativity

Fixes #4104

Signed-off-by: Rob Syme <[email protected]>
1 parent 1645098 commit d105842

File tree

5 files changed

+264
-21
lines changed

5 files changed

+264
-21
lines changed

modules/nextflow/src/main/groovy/nextflow/extension/CombineOp.groovy

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class CombineOp {
4848

4949
private Map<Object,List> rightValues = [:]
5050

51+
private Map<Object,Object> originalKeyMap = [:]
52+
5153
private static final int LEFT = 0
5254

5355
private static final int RIGHT = 1
@@ -93,7 +95,7 @@ class CombineOp {
9395
opts.onNext = {
9496
if( pivot ) {
9597
def pair = makeKey(pivot, it)
96-
emit(target, index, pair.keys, pair.values)
98+
emit(target, index, pair)
9799
}
98100
else {
99101
emit(target, index, NONE, it)
@@ -120,15 +122,44 @@ class CombineOp {
120122
}
121123

122124
@PackageScope
123-
synchronized void emit( DataflowWriteChannel target, int index, List p, v ) {
125+
synchronized void emit( DataflowWriteChannel target, int index, KeyPair pair ) {
126+
emit(target, index, pair.originalKeys, pair.keys, pair.values)
127+
}
128+
129+
@PackageScope
130+
synchronized void emit( DataflowWriteChannel target, int index, List originalKeys, List keys, v ) {
131+
def p = keys // Use normalized keys for matching
132+
133+
// Store/update the mapping from normalized key to original key
134+
// Prefer GroupKey over plain keys
135+
def existingOriginal = originalKeyMap.get(p)
136+
if (existingOriginal == null) {
137+
originalKeyMap[p] = originalKeys
138+
} else {
139+
// Check if any of the new original keys is a GroupKey
140+
for (int i = 0; i < originalKeys.size(); i++) {
141+
def newKey = originalKeys[i]
142+
def oldKey = existingOriginal instanceof List ? existingOriginal[i] : existingOriginal
143+
if (newKey instanceof GroupKey && !(oldKey instanceof GroupKey)) {
144+
originalKeyMap[p] = originalKeys
145+
break
146+
}
147+
}
148+
}
149+
150+
// Use the best available original key (preferring GroupKey)
151+
def bestOriginalKeys = originalKeyMap[p]
152+
153+
// Ensure bestOriginalKeys is a List for the tuple method
154+
def bestKeysList = bestOriginalKeys instanceof List ? bestOriginalKeys : [bestOriginalKeys]
124155

125156
if( leftValues[p] == null ) leftValues[p] = []
126157
if( rightValues[p] == null ) rightValues[p] = []
127158

128159
if( index == LEFT ) {
129160
log.trace "combine >> left >> by=$p; val=$v; right-values: ${rightValues[p]}"
130161
for ( Object x : rightValues[p] ) {
131-
target.bind( tuple(p, v, x) )
162+
target.bind( tuple(bestKeysList, v, x) ) // Use best original keys in output
132163
}
133164
leftValues[p].add(v)
134165
return
@@ -137,7 +168,7 @@ class CombineOp {
137168
if( index == RIGHT ) {
138169
log.trace "combine >> right >> by=$p; val=$v; right-values: ${leftValues[p]}"
139170
for ( Object x : leftValues[p] ) {
140-
target.bind( tuple(p, x, v) )
171+
target.bind( tuple(bestKeysList, x, v) ) // Use best original keys in output
141172
}
142173
rightValues[p].add(v)
143174
return
@@ -146,6 +177,12 @@ class CombineOp {
146177
throw new IllegalArgumentException("Not a valid spread operator index: $index")
147178
}
148179

180+
// Four-argument overload kept for callers that pass a plain key list (e.g. NONE when there is no pivot).
@PackageScope
181+
synchronized void emit( DataflowWriteChannel target, int index, List p, v ) {
182+
// No separate original keys here, so `p` serves as both the original and the normalized key list
183+
emit(target, index, p, p, v)
184+
}
185+
149186
DataflowWriteChannel apply() {
150187

151188
target = CH.create()

modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,10 @@ class DataflowHelper {
361361
if( pivot != [0] )
362362
throw new IllegalArgumentException("Not a valid `by` index: $pivot")
363363
result.addKey(entry)
364-
result.values = []
365364
return result
366365
}
367366

368367
def list = (List)entry
369-
result.keys = new ArrayList(pivot.size())
370-
result.values = new ArrayList(list.size())
371368

372369
for( int i=0; i<list.size(); i++ ) {
373370
if( i in pivot )

modules/nextflow/src/main/groovy/nextflow/extension/JoinOp.groovy

Lines changed: 102 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class JoinOp {
5858

5959
private Set uniqueKeys = new LinkedHashSet()
6060

61+
private Map<Object,Object> originalKeyMap = new HashMap()
62+
6163
JoinOp( DataflowReadChannel source, DataflowReadChannel target, Map params = null ) {
6264
CheckHelper.checkParams('join', params, JOIN_PARAMS)
6365
this.source = source
@@ -173,6 +175,24 @@ class JoinOp {
173175
// get the index key for this object
174176
final item0 = DataflowHelper.makeKey(pivot, data)
175177

178+
// Store the mapping from normalized key to original key
179+
// Prefer GroupKey over plain keys
180+
def existingOriginal = originalKeyMap.get(item0.keys)
181+
if (existingOriginal == null) {
182+
originalKeyMap[item0.keys] = item0.originalKeys
183+
} else {
184+
// Check if any of the new original keys is a GroupKey
185+
// If so, replace the existing mapping
186+
for (int i = 0; i < item0.originalKeys.size(); i++) {
187+
def newKey = item0.originalKeys[i]
188+
def oldKey = existingOriginal instanceof List ? existingOriginal[i] : existingOriginal
189+
if (newKey instanceof GroupKey && !(oldKey instanceof GroupKey)) {
190+
originalKeyMap[item0.keys] = item0.originalKeys
191+
break
192+
}
193+
}
194+
}
195+
176196
// check for unique keys
177197
checkForDuplicate(item0.keys, item0.values, index, false)
178198

@@ -190,8 +210,8 @@ class JoinOp {
190210
def entries = channels[index]
191211

192212
// add the received item to the list
193-
// when it is used in the gather op add always as the first item
194-
entries << item0.values
213+
// Store the full KeyPair to preserve original keys
214+
entries << item0
195215
setSingleton(index, item0.values.size()==0)
196216

197217
// now check if it has received an element matching for each channel
@@ -200,15 +220,39 @@ class JoinOp {
200220
}
201221

202222
def result = []
223+
224+
// Find the best key (prefer GroupKey) from all channels
225+
def bestOriginalKeys = null
226+
for (Map.Entry<Integer,List> entry : channels.entrySet()) {
227+
def channelItems = entry.getValue()
228+
if (channelItems && channelItems.size() > 0) {
229+
def keyPair = channelItems[0] as KeyPair
230+
if (bestOriginalKeys == null) {
231+
bestOriginalKeys = keyPair.originalKeys
232+
} else {
233+
// Check if this channel has a GroupKey version
234+
for (int i = 0; i < keyPair.originalKeys.size(); i++) {
235+
def candidateKey = keyPair.originalKeys[i]
236+
def currentKey = bestOriginalKeys instanceof List ? bestOriginalKeys[i] : bestOriginalKeys
237+
if (candidateKey instanceof GroupKey && !(currentKey instanceof GroupKey)) {
238+
bestOriginalKeys = keyPair.originalKeys
239+
break
240+
}
241+
}
242+
}
243+
}
244+
}
245+
203246
// add the key
204-
addToList(result, item0.keys)
247+
addToList(result, bestOriginalKeys ?: item0.originalKeys)
205248

206249
final itr = channels.iterator()
207250
while( itr.hasNext() ) {
208251
def entry = (Map.Entry<Integer,List>)itr.next()
209252

210253
def list = entry.getValue()
211-
addToList(result, list[0])
254+
def keyPair = list[0] as KeyPair
255+
addToList(result, keyPair.values)
212256

213257
list.remove(0)
214258
if( list.size() == 0 ) {
@@ -221,6 +265,17 @@ class JoinOp {
221265
return result
222266
}
223267

268+
// Placeholder lookup for a buffered item: looks up `key` in `buffer`, then `channelIndex` in the per-key map.
// NOTE(review): every path returns null and no caller appears in the visible hunks - finish or remove before merging.
269+
private def getOriginalDataFromBuffer(Map<Object,Map<Integer,List>> buffer, Object key, int channelIndex) {
270+
def channels = buffer.get(key)
271+
if (channels == null) return null
272+
def items = channels.get(channelIndex)
273+
if (items == null || items.isEmpty()) return null
274+
// Not implemented: rebuilding the original item would need both the buffered values and the key
275+
// The buffer would have to track the full original items for this to work
276+
return null // Always null, even when items are present for this key and channel
277+
}
278+
224279
private final void checkRemainder(Map<Object,Map<Integer,List>> buffers, int count, DataflowWriteChannel target ) {
225280
log.trace "Operator `join` remainder buffer: ${-> buffers}"
226281

@@ -231,17 +286,43 @@ class JoinOp {
231286

232287
boolean fill=false
233288
def result = new ArrayList(count+1)
234-
addToList(result, key)
289+
290+
// Find the best original key from available channels
291+
def bestOriginalKey = null
292+
for( int i=0; i<count; i++ ) {
293+
List items = entry[i]
294+
if( items && items.size() > 0 ) {
295+
def keyPair = items[0] as KeyPair
296+
if (bestOriginalKey == null) {
297+
bestOriginalKey = keyPair.originalKeys
298+
} else {
299+
// Check if this channel has a GroupKey version
300+
for (int j = 0; j < keyPair.originalKeys.size(); j++) {
301+
def candidateKey = keyPair.originalKeys[j]
302+
def currentKey = bestOriginalKey instanceof List ? bestOriginalKey[j] : bestOriginalKey
303+
if (candidateKey instanceof GroupKey && !(currentKey instanceof GroupKey)) {
304+
bestOriginalKey = keyPair.originalKeys
305+
break
306+
}
307+
}
308+
}
309+
}
310+
}
311+
312+
// Use the best available original key, or fall back to the map key
313+
def originalKey = bestOriginalKey ?: originalKeyMap.get(key) ?: key
314+
addToList(result, originalKey)
235315

236316
for( int i=0; i<count; i++ ) {
237-
List values = entry[i]
238-
if( values ) {
317+
List items = entry[i]
318+
if( items ) {
319+
def keyPair = items[0] as KeyPair
239320
// track unique keys
240-
checkForDuplicate(key, values[0], i,false)
321+
checkForDuplicate(key, keyPair.values, i, false)
241322

242-
addToList(result, values[0])
323+
addToList(result, keyPair.values)
243324
fill |= true
244-
values.remove(0)
325+
items.remove(0)
245326
}
246327
else if( !singleton(i) ) {
247328
addToList(result, null)
@@ -277,7 +358,17 @@ class JoinOp {
277358

278359
result[key] = []
279360
for( Map.Entry entry : remainder0 ) {
280-
result[key].add(csv0(entry.value,','))
361+
def items = entry.value
362+
def values = []
363+
// Extract values from KeyPair objects
364+
items.each { item ->
365+
if (item instanceof KeyPair) {
366+
values.add(item.values)
367+
} else {
368+
values.add(item)
369+
}
370+
}
371+
result[key].add(csv0(values,','))
281372
}
282373
}
283374

modules/nextflow/src/main/groovy/nextflow/extension/KeyPair.groovy

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,17 @@ import nextflow.extension.GroupKey
3131
@EqualsAndHashCode
3232
class KeyPair {
3333
List keys
34+
List originalKeys
3435
List values
3536

37+
KeyPair() {
38+
this.keys = []
39+
this.originalKeys = []
40+
this.values = []
41+
}
42+
3643
void addKey(el) {
37-
if (keys == null) {
38-
keys = []
39-
}
44+
originalKeys.add(el)
4045
keys.add(safeStr(el))
4146
}
4247

0 commit comments

Comments
 (0)