Skip to content

Commit be68d35

Browse files
committed
Change to a single concatenate function
1 parent af613f7 commit be68d35

File tree

4 files changed

+88
-101
lines changed

4 files changed

+88
-101
lines changed

python/tests/test_table_transforms.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,6 @@
3939
# we can remove this.
4040

4141

42-
class TestShift:
43-
# Most testing is done on the TreeSequence methods. Here we just check
44-
# that the TableCollection methods work even if they produce an invalid ts
45-
def test_too_negative(self):
46-
tables = tskit.Tree.generate_comb(2).tree_sequence.dump_tables()
47-
tables.shift(-1)
48-
assert np.min(tables.edges.left) == -1
49-
50-
def test_bad_seq_len(self):
51-
tables = tskit.Tree.generate_comb(2).tree_sequence.dump_tables()
52-
tables.shift(1, sequence_length=0.5)
53-
assert tables.sequence_length == 0.5
54-
assert np.max(tables.edges.right) == 2
55-
56-
5742
def delete_older_definition(tables, time):
5843
node_time = tables.nodes.time
5944
edges = tables.edges.copy()

python/tests/test_topology.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7091,6 +7091,7 @@ def test_shift(self, shift):
70917091
assert np.min(ts.tables.edges.left) == 1 + shift
70927092
assert np.max(ts.tables.edges.right) == 2 + shift
70937093
assert np.all(ts.tables.sites.position == 1.5 + shift)
7094+
assert len(list(ts.trees())) == ts.num_trees
70947095

70957096
def test_sequence_length(self):
70967097
ts = tskit.Tree.generate_comb(2).tree_sequence
@@ -7147,11 +7148,31 @@ def test_simple(self):
71477148
ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim()
71487149
assert ts4.equals(ts2.simplify(), ignore_provenance=True)
71497150

7151+
def test_multiple(self):
7152+
np.random.seed(42)
7153+
ts3 = [
7154+
tskit.Tree.generate_comb(5, span=2).tree_sequence,
7155+
tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence,
7156+
tskit.Tree.generate_star(5, span=5).tree_sequence,
7157+
]
7158+
for i in range(1, len(ts3)):
7159+
# shuffle the sample nodes so they don't have the same IDs
7160+
ts3[i] = ts3[i].subset(np.random.permutation(ts3[i].num_nodes))
7161+
assert not np.all(ts3[0].samples() == ts3[1].samples())
7162+
assert not np.all(ts3[0].samples() == ts3[2].samples())
7163+
assert not np.all(ts3[1].samples() == ts3[2].samples())
7164+
ts = ts3[0].concatenate(*ts3[1:])
7165+
assert ts.sequence_length == sum([t.sequence_length for t in ts3])
7166+
assert ts.num_nodes - ts.num_samples == sum(
7167+
[t.num_nodes - t.num_samples for t in ts3]
7168+
)
7169+
assert np.all(ts.samples() == ts3[0].samples())
7170+
71507171
def test_empty(self):
71517172
empty_ts = tskit.TableCollection(10).tree_sequence()
7152-
ts = empty_ts.concatenate(empty_ts)
7173+
ts = empty_ts.concatenate(empty_ts, empty_ts, empty_ts)
71537174
assert ts.num_nodes == 0
7154-
assert ts.sequence_length == 20
7175+
assert ts.sequence_length == 40
71557176

71567177
def test_samples_at_end(self):
71577178
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
@@ -7182,7 +7203,7 @@ def test_some_shared_samples(self):
71827203
shared = np.full(ts2.num_nodes, tskit.NULL)
71837204
shared[0] = 1
71847205
shared[1] = 0
7185-
joint_ts = ts1.concatenate(ts2, node_mapping=shared)
7206+
joint_ts = ts1.concatenate(ts2, node_mappings=[shared])
71867207
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
71877208
assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2
71887209
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2

python/tskit/tables.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4005,6 +4005,7 @@ def shift(self, value, *, sequence_length=None, record_provenance=True):
40054005
:param sequence_length: The new sequence length of the tree sequence. If
40064006
``None`` (default) add `value` to the sequence length.
40074007
"""
4008+
self.drop_index()
40084009
self.edges.left += value
40094010
self.edges.right += value
40104011
self.sites.position += value
@@ -4023,64 +4024,6 @@ def shift(self, value, *, sequence_length=None, record_provenance=True):
40234024
record=json.dumps(provenance.get_provenance_dict(parameters))
40244025
)
40254026

4026-
def concatenate(
4027-
self, other, *, node_mapping=None, record_provenance=True, **kwargs
4028-
):
4029-
"""
4030-
Concatenate another table collection to the right of this one. This
4031-
{meth}`shift`s the other table coordinate rightwards, then calls
4032-
{meth}`union` with ``check_shared_equality=False`` and the provided
4033-
``node_mapping``. If no node mapping is given, the two table
4034-
collections must have the same number of samples, and those are treated
4035-
(in numerical order) as shared between the two table collections.
4036-
This is identical to :meth:`TreeSequence.concatenate` but
4037-
acts *in place* to alter the data in this :class:`TableCollection`.
4038-
4039-
.. note::
4040-
To add gaps between the concatenated tables, use :meth:`shift` before
4041-
concatenating; to remove gaps, use :meth:`trim`.
4042-
4043-
:param TableCollection other: The other table collection to add to the right
4044-
of this one.
4045-
:param list node_mapping: An array of integers of the same length as the number
4046-
of nodes in ``other``, where the _k_'th element gives the id of the node in
4047-
the current table collection corresponding to node _k_ in the other table
4048-
collection (see {meth}`union`). If None (default), only the sample nodes
4049-
between the two node tables, in numerical order, are mapped to each other.
4050-
:param bool record_provenance: If True (default), record details of this call to
4051-
``concatenate`` in the returned tree sequence's provenance information
4052-
(Default: True).
4053-
:param \\**kwargs: Additional keyword arguments to pass to {meth}`union`
4054-
(e.g. ``add_populations``).
4055-
"""
4056-
if node_mapping is None:
4057-
samples = np.where(self.nodes.flags & tskit.NODE_IS_SAMPLE)[0]
4058-
other_samples = np.where(other.nodes.flags & tskit.NODE_IS_SAMPLE)[0]
4059-
if len(other_samples) != len(samples):
4060-
raise ValueError(
4061-
"each `other` must have the same number of samples as `self`"
4062-
)
4063-
node_mapping = np.full(other.nodes.num_rows, tskit.NULL, dtype=np.int32)
4064-
node_mapping[other_samples] = samples
4065-
other.shift(self.sequence_length, record_provenance=False)
4066-
self.sequence_length = other.sequence_length
4067-
# NB: should we use a different default for add_populations?
4068-
self.union(
4069-
other,
4070-
node_mapping=node_mapping,
4071-
check_shared_equality=False, # Needed as checks fail with internal samples
4072-
record_provenance=False,
4073-
**kwargs,
4074-
)
4075-
if record_provenance:
4076-
parameters = {
4077-
"command": "concatenate",
4078-
"TODO": "add concatenate parameters", # tricky as both have provenances
4079-
}
4080-
self.provenances.add_row(
4081-
record=json.dumps(provenance.get_provenance_dict(parameters))
4082-
)
4083-
40844027
def delete_older(self, time):
40854028
"""
40864029
Deletes edge, mutation and migration information at least as old as

python/tskit/trees.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import functools
3333
import io
3434
import itertools
35+
import json
3536
import math
3637
import numbers
3738
import warnings
@@ -46,6 +47,7 @@
4647
import tskit.combinatorics as combinatorics
4748
import tskit.drawing as drawing
4849
import tskit.metadata as metadata_module
50+
import tskit.provenance as provenance
4951
import tskit.tables as tables
5052
import tskit.text_formats as text_formats
5153
import tskit.util as util
@@ -7091,39 +7093,75 @@ def shift(self, value, sequence_length=None, record_provenance=True):
70917093
return ts
70927094

70937095
def concatenate(
7094-
self, other, *, node_mapping=None, record_provenance=True, **kwargs
7096+
self, *args, node_mappings=None, record_provenance=True, add_populations=None
70957097
):
7096-
"""
7097-
Concatenate another tree sequence to the right of this one. This shifts the
7098-
coordinate system of the other tree sequence rightwards, then calls
7099-
{meth}`union` with the provided ``node_mapping``. If no node mapping
7100-
is given, matches sample nodes only, in numerical order.
7098+
r"""
7099+
Concatenate a set of tree sequences to the right of this one, by repeatedly
7100+
calling {meth}`union` with an (optional)
7101+
node mapping for each of the ``others``. If any node mapping is ``None``
7102+
only map the sample nodes between the input tree sequence and this one,
7103+
based on the numerical order of sample node IDs.
71017104
71027105
.. note::
71037106
To add gaps between the concatenated tables, use :meth:`shift` or
71047107
to remove gaps, use :meth:`trim` before concatenating.
71057108
7106-
:param TableCollection other: The other table collection to add to the right
7107-
of this one.
7108-
:param list node_mapping: An array of integers of the same length as the number
7109-
of nodes in ``other``, where the _k_'th element gives the id of the node in
7110-
the current table collection corresponding to node _k_ in the other table
7111-
collection (see :meth:`union`). If None (default), only the sample nodes
7112-
between the two node tables, in numerical order, are mapped to each other.
7113-
:param bool record_provenance: If True (default), record details of this call to
7114-
``concatenate`` in the returned tree sequence's provenance information
7115-
(Default: True).
7116-
:param \\**kwargs: Additional keyword arguments to pass to :meth:`union`,
7117-
e.g. ``add_populations``.
7118-
"""
7109+
:param TreeSequence \*args: A list of other tree sequences to append to
7110+
the right of this one.
7111+
:param Union[list, None] node_mappings: An list of node mappings for each
7112+
input tree sequence in ``args``. Each should either be an array of
7113+
integers of the same length as the number of nodes in the equivalent
7114+
input tree sequence (see :meth:`union` for details), or ``None``.
7115+
If ``None``, only sample nodes are mapped to each other.
7116+
Default: ``None``, treated as ``[None] * len(args)``.
7117+
:param bool record_provenance: If True (default), record details of this
7118+
call to ``concatenate`` in the returned tree sequence's provenance
7119+
information (Default: True).
7120+
:param bool add_populations: If True (default), nodes new to ``self`` will
7121+
be assigned new population IDs (see :meth:`union`)
7122+
"""
7123+
if node_mappings is None:
7124+
node_mappings = [None] * len(args)
7125+
if add_populations is None:
7126+
add_populations = True
7127+
if len(node_mappings) != len(args):
7128+
raise ValueError(
7129+
"You must provide the same number of node_mappings as args"
7130+
)
71197131

7132+
samples = self.samples()
71207133
tables = self.dump_tables()
7121-
tables.concatenate(
7122-
other.tables,
7123-
node_mapping=node_mapping,
7124-
record_provenance=record_provenance,
7125-
**kwargs,
7126-
)
7134+
tables.drop_index()
7135+
7136+
for node_mapping, other in zip(node_mappings, args):
7137+
if node_mapping is None:
7138+
other_samples = other.samples()
7139+
if len(other_samples) != len(samples):
7140+
raise ValueError(
7141+
"each `other` must have the same number of samples as `self`"
7142+
)
7143+
node_mapping = np.full(other.num_nodes, tskit.NULL, dtype=np.int32)
7144+
node_mapping[other_samples] = samples
7145+
other_tables = other.dump_tables()
7146+
other_tables.shift(tables.sequence_length, record_provenance=False)
7147+
tables.sequence_length = other_tables.sequence_length
7148+
# NB: should we use a different default for add_populations?
7149+
tables.union(
7150+
other_tables,
7151+
node_mapping=node_mapping,
7152+
check_shared_equality=False, # Else checks fail with internal samples
7153+
record_provenance=False,
7154+
add_populations=add_populations,
7155+
)
7156+
if record_provenance:
7157+
parameters = {
7158+
"command": "concatenate",
7159+
"TODO": "add concatenate parameters", # tricky as both have provenances
7160+
}
7161+
tables.provenances.add_row(
7162+
record=json.dumps(provenance.get_provenance_dict(parameters))
7163+
)
7164+
71277165
return tables.tree_sequence()
71287166

71297167
def split_edges(self, time, *, flags=None, population=None, metadata=None):

0 commit comments

Comments
 (0)