Ensure that groupKey objects are treated as their groupTarget for the purposes of joins and combines #6139

Open
wants to merge 5 commits into master
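For context, the behaviour this change targets can be shown with a minimal, hypothetical workflow sketch (channel names and values are invented, not taken from this PR): a channel keyed with groupKey is joined against a plain-keyed channel, and the joined tuple should keep the GroupKey so that a downstream groupTuple can still use its size hint. The diff below threads the original (possibly GroupKey) key through CombineOp, JoinOp and KeyPair so it can be re-emitted in the output tuples.

workflow {
    // left channel keyed by groupKey(id, size); the size hint should survive the join
    ch_sized = Channel.of( ['sampleA', 3], ['sampleB', 2] )
        .map { id, n -> [ groupKey(id, n), n ] }

    // right channel keyed by the plain id
    ch_meta = Channel.of( ['sampleA', 'libX'], ['sampleB', 'libY'] )

    // with this change the first element of each joined tuple is still a GroupKey
    ch_sized
        .join( ch_meta )
        .view { key, n, lib -> "$key size=${key.groupSize} n=$n lib=$lib" }
}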
@@ -48,6 +48,8 @@ class CombineOp {

private Map<Object,List> rightValues = [:]

private Map<Object,Object> originalKeyMap = [:]

private static final int LEFT = 0

private static final int RIGHT = 1
@@ -93,7 +95,7 @@ class CombineOp {
opts.onNext = {
if( pivot ) {
def pair = makeKey(pivot, it)
emit(target, index, pair.keys, pair.values)
emit(target, index, pair)
}
else {
emit(target, index, NONE, it)
@@ -120,15 +122,44 @@ class CombineOp {
}

@PackageScope
synchronized void emit( DataflowWriteChannel target, int index, List p, v ) {
synchronized void emit( DataflowWriteChannel target, int index, KeyPair pair ) {
emit(target, index, pair.originalKeys, pair.keys, pair.values)
}

@PackageScope
synchronized void emit( DataflowWriteChannel target, int index, List originalKeys, List keys, v ) {
def p = keys // Use normalized keys for matching

// Store/update the mapping from normalized key to original key
// Prefer GroupKey over plain keys
def existingOriginal = originalKeyMap.get(p)
if (existingOriginal == null) {
originalKeyMap[p] = originalKeys
} else {
// Check if any of the new original keys is a GroupKey
for (int i = 0; i < originalKeys.size(); i++) {
def newKey = originalKeys[i]
def oldKey = existingOriginal instanceof List ? existingOriginal[i] : existingOriginal
if (newKey instanceof GroupKey && !(oldKey instanceof GroupKey)) {
originalKeyMap[p] = originalKeys
break
}
}
}

// Use the best available original key (preferring GroupKey)
def bestOriginalKeys = originalKeyMap[p]

// Ensure bestOriginalKeys is a List for the tuple method
def bestKeysList = bestOriginalKeys instanceof List ? bestOriginalKeys : [bestOriginalKeys]

if( leftValues[p] == null ) leftValues[p] = []
if( rightValues[p] == null ) rightValues[p] = []

if( index == LEFT ) {
log.trace "combine >> left >> by=$p; val=$v; right-values: ${rightValues[p]}"
for ( Object x : rightValues[p] ) {
target.bind( tuple(p, v, x) )
target.bind( tuple(bestKeysList, v, x) ) // Use best original keys in output
}
leftValues[p].add(v)
return
@@ -137,7 +168,7 @@ class CombineOp {
if( index == RIGHT ) {
log.trace "combine >> right >> by=$p; val=$v; right-values: ${leftValues[p]}"
for ( Object x : leftValues[p] ) {
target.bind( tuple(p, x, v) )
target.bind( tuple(bestKeysList, x, v) ) // Use best original keys in output
}
rightValues[p].add(v)
return
@@ -146,6 +177,12 @@ class CombineOp {
throw new IllegalArgumentException("Not a valid spread operator index: $index")
}

@PackageScope
synchronized void emit( DataflowWriteChannel target, int index, List p, v ) {
// Legacy method for when pivot is NONE
emit(target, index, p, p, v)
}

DataflowWriteChannel apply() {

target = CH.create()
@@ -360,14 +360,11 @@ class DataflowHelper {
if( !(entry instanceof List) ) {
if( pivot != [0] )
throw new IllegalArgumentException("Not a valid `by` index: $pivot")
result.keys = [entry]
result.values = []
result.addKey(entry)
return result
}

def list = (List)entry
result.keys = new ArrayList(pivot.size())
result.values = new ArrayList(list.size())

for( int i=0; i<list.size(); i++ ) {
if( i in pivot )
76 changes: 65 additions & 11 deletions modules/nextflow/src/main/groovy/nextflow/extension/JoinOp.groovy
@@ -190,8 +190,7 @@ class JoinOp {
def entries = channels[index]

// add the received item to the list
// when it is used in the gather op add always as the first item
entries << item0.values
entries << item0
setSingleton(index, item0.values.size()==0)

// now check if it has received an element matching for each channel
@@ -200,15 +199,20 @@ class JoinOp {
}

def result = []

// Find the best key (prefer GroupKey) from all channels
def bestOriginalKeys = findBestOriginalKeys(channels)

// add the key
addToList(result, item0.keys)
addToList(result, bestOriginalKeys ?: item0.originalKeys)

final itr = channels.iterator()
while( itr.hasNext() ) {
def entry = (Map.Entry<Integer,List>)itr.next()

def list = entry.getValue()
addToList(result, list[0])
def keyPair = list[0] as KeyPair
addToList(result, keyPair.values)

list.remove(0)
if( list.size() == 0 ) {
@@ -221,6 +225,7 @@ class JoinOp {
return result
}


private final void checkRemainder(Map<Object,Map<Integer,List>> buffers, int count, DataflowWriteChannel target ) {
log.trace "Operator `join` remainder buffer: ${-> buffers}"

@@ -231,17 +236,24 @@ class JoinOp {

boolean fill=false
def result = new ArrayList(count+1)
addToList(result, key)

// Find the best original key from available channels
def bestOriginalKey = findBestOriginalKeys(entry)

// Use the best available original key, or fall back to the map key
def originalKey = bestOriginalKey ?: key
addToList(result, originalKey)

for( int i=0; i<count; i++ ) {
List values = entry[i]
if( values ) {
List items = entry[i]
if( items ) {
def keyPair = items[0] as KeyPair
// track unique keys
checkForDuplicate(key, values[0], i,false)
checkForDuplicate(key, keyPair.values, i, false)

addToList(result, values[0])
addToList(result, keyPair.values)
fill |= true
values.remove(0)
items.remove(0)
}
else if( !singleton(i) ) {
addToList(result, null)
@@ -260,6 +272,38 @@ class JoinOp {
}
}

/**
* Finds the best original keys from a map of channel items, preferring GroupKey over plain keys
*
* @param channelItems Map of channel index to list of items (KeyPair objects)
* @return The best original keys found, or null if no items available
*/
private def findBestOriginalKeys(Map<Integer,List> channelItems) {
def bestOriginalKeys = null

for (Map.Entry<Integer,List> entry : channelItems.entrySet()) {
def items = entry.getValue()
if (items && items.size() > 0) {
def keyPair = items[0] as KeyPair
if (bestOriginalKeys == null) {
bestOriginalKeys = keyPair.originalKeys
} else {
// Check if this channel has a GroupKey version
for (int i = 0; i < keyPair.originalKeys.size(); i++) {
def candidateKey = keyPair.originalKeys[i]
def currentKey = bestOriginalKeys instanceof List ? bestOriginalKeys[i] : bestOriginalKeys
if (candidateKey instanceof GroupKey && !(currentKey instanceof GroupKey)) {
bestOriginalKeys = keyPair.originalKeys
break
}
}
}
}
}

return bestOriginalKeys
}

protected void checkForDuplicate( key, value, int dir, boolean add ) {
if( failOnDuplicate && ( (add && !uniqueKeys.add(key)) || (!add && uniqueKeys.contains(key)) ) ) {
final msg = "Detected join operation duplicate emission on ${dir==0 ? 'left' : 'right'} channel -- offending element: key=${csv0(key,',')}; value=${csv0(value,',')}"
@@ -277,7 +321,17 @@ class JoinOp {

result[key] = []
for( Map.Entry entry : remainder0 ) {
result[key].add(csv0(entry.value,','))
def items = entry.value
def values = []
// Extract values from KeyPair objects
items.each { item ->
if (item instanceof KeyPair) {
values.add(item.values)
} else {
values.add(item)
}
}
result[key].add(csv0(values,','))
}
}

@@ -19,6 +19,7 @@ package nextflow.extension
import groovy.transform.CompileStatic
import groovy.transform.EqualsAndHashCode
import groovy.transform.ToString
import nextflow.extension.GroupKey

/**
* Implements a helper key-value object used in dataflow operators
@@ -30,9 +31,17 @@ import groovy.transform.ToString
@EqualsAndHashCode
class KeyPair {
List keys
List originalKeys
List values

KeyPair() {
this.keys = []
this.originalKeys = []
this.values = []
}

void addKey(el) {
originalKeys.add(el)
keys.add(safeStr(el))
}

@@ -41,6 +50,10 @@ class KeyPair {
}

static private safeStr(key) {
key instanceof GString ? key.toString() : key
if (key instanceof GString)
return key.toString()
if (key instanceof GroupKey)
return key.getGroupTarget()
return key
Comment on lines 52 to +57

Member
This sounds very reasonable, it may be worth adding a unit test

Collaborator Author
Agreed. Will fix that today.

}
}
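Following up on the review thread above, a minimal sketch of such a unit test might look like this (the class name is hypothetical and the test only exercises the addKey/safeStr path added in this diff):

package nextflow.extension

import spock.lang.Specification

class KeyPairGroupKeyTest extends Specification {

    def 'should normalize a GroupKey to its target while keeping the original key' () {
        given:
        def gk = new GroupKey('sampleA', 2)
        def pair = new KeyPair()

        when:
        pair.addKey(gk)

        then:
        pair.keys == ['sampleA']      // normalized key used for matching
        pair.originalKeys == [gk]     // original GroupKey preserved for the output tuple
    }
}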
@@ -0,0 +1,113 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package nextflow.extension

import nextflow.Channel
import nextflow.Session
import spock.lang.Specification

/**
* Test GroupKey preservation in join operations
*
* @author Your Name
*/
class JoinOpGroupKeyTest extends Specification {

def setup() {
new Session()
}

def 'should preserve GroupKey when joining channels' () {
given:
def key1 = new GroupKey('X', 2)
def key2 = new GroupKey('Y', 3)

def ch1 = Channel.of([key1, 1], [key2, 2])
def ch2 = Channel.of(['X', 'a'], ['Y', 'b'])

when:
def op = new JoinOp(ch1, ch2)
def result = op.apply().toList().getVal()

then:
result.size() == 2

// Check that GroupKey is preserved in the output
result.each { tuple ->
assert tuple[0] instanceof GroupKey
assert tuple.size() == 3
}

// Verify the actual values
def sorted = result.sort { it[0].toString() }
sorted[0][0].toString() == 'X'
sorted[0][0].groupSize == 2
sorted[0][1] == 1
sorted[0][2] == 'a'

sorted[1][0].toString() == 'Y'
sorted[1][0].groupSize == 3
sorted[1][1] == 2
sorted[1][2] == 'b'
}

def 'should preserve GroupKey when GroupKey is on right channel' () {
given:
def key1 = new GroupKey('X', 2)
def key2 = new GroupKey('Y', 3)

def ch1 = Channel.of(['X', 'a'], ['Y', 'b'])
def ch2 = Channel.of([key1, 1], [key2, 2])

when:
def op = new JoinOp(ch1, ch2)
def result = op.apply().toList().getVal()

then:
result.size() == 2

// Check that GroupKey is preserved in the output
result.each { tuple ->
assert tuple[0] instanceof GroupKey
assert tuple.size() == 3
}
}

def 'should handle mix of GroupKey and plain keys correctly' () {
given:
def key1 = new GroupKey('X', 2)

def ch1 = Channel.of([key1, 1], ['Y', 2]) // Mix of GroupKey and plain key
def ch2 = Channel.of(['X', 'a'], ['Y', 'b'])

when:
def op = new JoinOp(ch1, ch2)
def result = op.apply().toList().getVal().sort { it[0].toString() }

then:
result.size() == 2

// First tuple should have GroupKey
result[0][0] instanceof GroupKey
result[0][0].toString() == 'X'
result[0][0].groupSize == 2

// Second tuple should have plain string
result[1][0] == 'Y'
!(result[1][0] instanceof GroupKey)
}
}
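The CombineOp changes earlier in this diff are not exercised by the test class above; at the pipeline level the intended behaviour would look roughly like the following hypothetical example (not part of this PR):

workflow {
    // left channel keyed by a groupKey, right channel keyed by the plain value
    left  = Channel.of( [ groupKey('chr1', 4), 'reads.fq' ] )
    right = Channel.of( [ 'chr1', 'ref.fa' ] )

    // with this change the combined tuple should carry the GroupKey
    // rather than just its bare target string
    left.combine( right, by: 0 )
        .view { key, reads, ref -> "${key.getClass().simpleName}($key): $reads + $ref" }
}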