diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/CombineOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/CombineOp.groovy index 8ef6765faf..f93f3cab27 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/CombineOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/CombineOp.groovy @@ -48,6 +48,8 @@ class CombineOp { private Map rightValues = [:] + private Map originalKeyMap = [:] + private static final int LEFT = 0 private static final int RIGHT = 1 @@ -93,7 +95,7 @@ class CombineOp { opts.onNext = { if( pivot ) { def pair = makeKey(pivot, it) - emit(target, index, pair.keys, pair.values) + emit(target, index, pair) } else { emit(target, index, NONE, it) @@ -120,7 +122,36 @@ class CombineOp { } @PackageScope - synchronized void emit( DataflowWriteChannel target, int index, List p, v ) { + synchronized void emit( DataflowWriteChannel target, int index, KeyPair pair ) { + emit(target, index, pair.originalKeys, pair.keys, pair.values) + } + + @PackageScope + synchronized void emit( DataflowWriteChannel target, int index, List originalKeys, List keys, v ) { + def p = keys // Use normalized keys for matching + + // Store/update the mapping from normalized key to original key + // Prefer GroupKey over plain keys + def existingOriginal = originalKeyMap.get(p) + if (existingOriginal == null) { + originalKeyMap[p] = originalKeys + } else { + // Check if any of the new original keys is a GroupKey + for (int i = 0; i < originalKeys.size(); i++) { + def newKey = originalKeys[i] + def oldKey = existingOriginal instanceof List ? existingOriginal[i] : existingOriginal + if (newKey instanceof GroupKey && !(oldKey instanceof GroupKey)) { + originalKeyMap[p] = originalKeys + break + } + } + } + + // Use the best available original key (preferring GroupKey) + def bestOriginalKeys = originalKeyMap[p] + + // Ensure bestOriginalKeys is a List for the tuple method + def bestKeysList = bestOriginalKeys instanceof List ? bestOriginalKeys : [bestOriginalKeys] if( leftValues[p] == null ) leftValues[p] = [] if( rightValues[p] == null ) rightValues[p] = [] @@ -128,7 +159,7 @@ class CombineOp { if( index == LEFT ) { log.trace "combine >> left >> by=$p; val=$v; right-values: ${rightValues[p]}" for ( Object x : rightValues[p] ) { - target.bind( tuple(p, v, x) ) + target.bind( tuple(bestKeysList, v, x) ) // Use best original keys in output } leftValues[p].add(v) return @@ -137,7 +168,7 @@ class CombineOp { if( index == RIGHT ) { log.trace "combine >> right >> by=$p; val=$v; right-values: ${leftValues[p]}" for ( Object x : leftValues[p] ) { - target.bind( tuple(p, x, v) ) + target.bind( tuple(bestKeysList, x, v) ) // Use best original keys in output } rightValues[p].add(v) return @@ -146,6 +177,12 @@ class CombineOp { throw new IllegalArgumentException("Not a valid spread operator index: $index") } + @PackageScope + synchronized void emit( DataflowWriteChannel target, int index, List p, v ) { + // Legacy method for when pivot is NONE + emit(target, index, p, p, v) + } + DataflowWriteChannel apply() { target = CH.create() diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy index 14c775b04e..ab875853b5 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy @@ -360,14 +360,11 @@ class DataflowHelper { if( !(entry instanceof List) ) { if( pivot != [0] ) throw new IllegalArgumentException("Not a valid `by` index: $pivot") - result.keys = [entry] - result.values = [] + result.addKey(entry) return result } def list = (List)entry - result.keys = new ArrayList(pivot.size()) - result.values = new ArrayList(list.size()) for( int i=0; i)itr.next() def list = entry.getValue() - addToList(result, list[0]) + def keyPair = list[0] as KeyPair + addToList(result, keyPair.values) list.remove(0) if( list.size() == 0 ) { @@ -221,6 +225,7 @@ class JoinOp { return result } + private final void checkRemainder(Map> buffers, int count, DataflowWriteChannel target ) { log.trace "Operator `join` remainder buffer: ${-> buffers}" @@ -231,17 +236,24 @@ class JoinOp { boolean fill=false def result = new ArrayList(count+1) - addToList(result, key) + + // Find the best original key from available channels + def bestOriginalKey = findBestOriginalKeys(entry) + + // Use the best available original key, or fall back to the map key + def originalKey = bestOriginalKey ?: key + addToList(result, originalKey) for( int i=0; i channelItems) { + def bestOriginalKeys = null + + for (Map.Entry entry : channelItems.entrySet()) { + def items = entry.getValue() + if (items && items.size() > 0) { + def keyPair = items[0] as KeyPair + if (bestOriginalKeys == null) { + bestOriginalKeys = keyPair.originalKeys + } else { + // Check if this channel has a GroupKey version + for (int i = 0; i < keyPair.originalKeys.size(); i++) { + def candidateKey = keyPair.originalKeys[i] + def currentKey = bestOriginalKeys instanceof List ? bestOriginalKeys[i] : bestOriginalKeys + if (candidateKey instanceof GroupKey && !(currentKey instanceof GroupKey)) { + bestOriginalKeys = keyPair.originalKeys + break + } + } + } + } + } + + return bestOriginalKeys + } + protected void checkForDuplicate( key, value, int dir, boolean add ) { if( failOnDuplicate && ( (add && !uniqueKeys.add(key)) || (!add && uniqueKeys.contains(key)) ) ) { final msg = "Detected join operation duplicate emission on ${dir==0 ? 'left' : 'right'} channel -- offending element: key=${csv0(key,',')}; value=${csv0(value,',')}" @@ -277,7 +321,17 @@ class JoinOp { result[key] = [] for( Map.Entry entry : remainder0 ) { - result[key].add(csv0(entry.value,',')) + def items = entry.value + def values = [] + // Extract values from KeyPair objects + items.each { item -> + if (item instanceof KeyPair) { + values.add(item.values) + } else { + values.add(item) + } + } + result[key].add(csv0(values,',')) } } diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/KeyPair.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/KeyPair.groovy index e794760b82..d2df6659c5 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/KeyPair.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/KeyPair.groovy @@ -19,6 +19,7 @@ package nextflow.extension import groovy.transform.CompileStatic import groovy.transform.EqualsAndHashCode import groovy.transform.ToString +import nextflow.extension.GroupKey /** * Implements an helper key-value helper object used in dataflow operators @@ -30,9 +31,17 @@ import groovy.transform.ToString @EqualsAndHashCode class KeyPair { List keys + List originalKeys List values + KeyPair() { + this.keys = [] + this.originalKeys = [] + this.values = [] + } + void addKey(el) { + originalKeys.add(el) keys.add(safeStr(el)) } @@ -41,6 +50,10 @@ class KeyPair { } static private safeStr(key) { - key instanceof GString ? key.toString() : key + if (key instanceof GString) + return key.toString() + if (key instanceof GroupKey) + return key.getGroupTarget() + return key } } diff --git a/modules/nextflow/src/test/groovy/nextflow/extension/JoinOpGroupKeyTest.groovy b/modules/nextflow/src/test/groovy/nextflow/extension/JoinOpGroupKeyTest.groovy new file mode 100644 index 0000000000..9024820fa1 --- /dev/null +++ b/modules/nextflow/src/test/groovy/nextflow/extension/JoinOpGroupKeyTest.groovy @@ -0,0 +1,113 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nextflow.extension + +import nextflow.Channel +import nextflow.Session +import spock.lang.Specification + +/** + * Test GroupKey preservation in join operations + * + * @author Your Name + */ +class JoinOpGroupKeyTest extends Specification { + + def setup() { + new Session() + } + + def 'should preserve GroupKey when joining channels' () { + given: + def key1 = new GroupKey('X', 2) + def key2 = new GroupKey('Y', 3) + + def ch1 = Channel.of([key1, 1], [key2, 2]) + def ch2 = Channel.of(['X', 'a'], ['Y', 'b']) + + when: + def op = new JoinOp(ch1, ch2) + def result = op.apply().toList().getVal() + + then: + result.size() == 2 + + // Check that GroupKey is preserved in the output + result.each { tuple -> + assert tuple[0] instanceof GroupKey + assert tuple.size() == 3 + } + + // Verify the actual values + def sorted = result.sort { it[0].toString() } + sorted[0][0].toString() == 'X' + sorted[0][0].groupSize == 2 + sorted[0][1] == 1 + sorted[0][2] == 'a' + + sorted[1][0].toString() == 'Y' + sorted[1][0].groupSize == 3 + sorted[1][1] == 2 + sorted[1][2] == 'b' + } + + def 'should preserve GroupKey when GroupKey is on right channel' () { + given: + def key1 = new GroupKey('X', 2) + def key2 = new GroupKey('Y', 3) + + def ch1 = Channel.of(['X', 'a'], ['Y', 'b']) + def ch2 = Channel.of([key1, 1], [key2, 2]) + + when: + def op = new JoinOp(ch1, ch2) + def result = op.apply().toList().getVal() + + then: + result.size() == 2 + + // Check that GroupKey is preserved in the output + result.each { tuple -> + assert tuple[0] instanceof GroupKey + assert tuple.size() == 3 + } + } + + def 'should handle mix of GroupKey and plain keys correctly' () { + given: + def key1 = new GroupKey('X', 2) + + def ch1 = Channel.of([key1, 1], ['Y', 2]) // Mix of GroupKey and plain key + def ch2 = Channel.of(['X', 'a'], ['Y', 'b']) + + when: + def op = new JoinOp(ch1, ch2) + def result = op.apply().toList().getVal().sort { it[0].toString() } + + then: + result.size() == 2 + + // First tuple should have GroupKey + result[0][0] instanceof GroupKey + result[0][0].toString() == 'X' + result[0][0].groupSize == 2 + + // Second tuple should have plain string + result[1][0] == 'Y' + !(result[1][0] instanceof GroupKey) + } +} \ No newline at end of file