From 43c87a444e4ecd441cca3ed356c9cb175c35cd5d Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 13 Mar 2025 10:12:41 +0100 Subject: [PATCH 1/2] New method ComparerCollection.merge --- modelskill/comparison/_collection.py | 23 ++++++++++++----------- tests/test_combine_comparers.py | 10 +++++----- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/modelskill/comparison/_collection.py b/modelskill/comparison/_collection.py index 4961f824a..a64d9ae33 100644 --- a/modelskill/comparison/_collection.py +++ b/modelskill/comparison/_collection.py @@ -89,17 +89,7 @@ class ComparerCollection(Mapping, Scoreable): plotter = ComparerCollectionPlotter def __init__(self, comparers: Iterable[Comparer]) -> None: - self._comparers: Dict[str, Comparer] = {} - - for cmp in comparers: - if cmp.name in self._comparers: - # comparer with this name already exists! - # maybe the user is trying to add a new model - # or a new time period - self._comparers[cmp.name] += cmp - else: - self._comparers[cmp.name] = cmp - + self._comparers = {cmp.name: cmp for cmp in comparers} self.plot = ComparerCollection.plotter(self) """Plot using the [](`~modelskill.comparison.ComparerCollectionPlotter`)""" @@ -186,6 +176,17 @@ def __repr__(self) -> str: out.append(f"{index}: {key} - {value.quantity}") return str.join("\n", out) + def merge(self, other: "ComparerCollection") -> "ComparerCollection": + # make a copy of self to avoid modifying the original + res = self.copy() + + for cmp in other: + if cmp.name in self._comparers: + res._comparers[cmp.name] += cmp + else: + res._comparers[cmp.name] = cmp + return res + def rename(self, mapping: Dict[str, str]) -> "ComparerCollection": """Rename observation, model or auxiliary data variables diff --git a/tests/test_combine_comparers.py b/tests/test_combine_comparers.py index 16dac97d3..f180e7112 100644 --- a/tests/test_combine_comparers.py +++ b/tests/test_combine_comparers.py @@ -59,7 +59,7 @@ def test_concat_model(o123, mrmike, mrmike2): assert cc2.mod_names[0] == cc12.mod_names[-1] assert cc2.end_time == cc12.end_time - cc12b = cc1 + cc2 + cc12b = cc1.merge(cc2) assert cc12b.score() == cc12.score() assert cc12b.n_points == cc12.n_points @@ -77,7 +77,7 @@ def test_concat_model_different_time(o123, mrmike, mr2days): assert cc2.mod_names[0] == cc12.mod_names[-1] assert cc2.end_time == cc12.end_time - cc12b = cc1 + cc2 + cc12b = cc1.merge(cc2) assert cc12b.score() == cc12.score() assert cc12b.n_points == cc12.n_points @@ -114,7 +114,7 @@ def test_concat_time_overlap(o123, mrmike): assert cc1.n_points > cc26.n_points # cc26 completely contained in cc1 - cc12 = cc1 + cc26 + cc12 = cc1.merge(cc26) assert cc1.start_time == cc12.start_time assert cc1.end_time == cc12.end_time assert cc1.n_points == cc12.n_points @@ -132,10 +132,10 @@ def test_concat_time_overlap(o123, mrmike): cc2 = ms.match([o1, o2, o3], mrmike) # cc26 _not_ completely contained in cc2 - cc12 = cc26 + cc2 + cc12 = cc26.merge(cc2) assert cc2.start_time > cc12.start_time assert cc2.end_time == cc12.end_time assert cc2.n_points < cc12.n_points - cc12a = cc2 + cc26 + cc12a = cc2.merge(cc26) assert cc12a.n_points == cc12.n_points From 22d6dceb85e20e2512277168732b58258a14046e Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 13 Mar 2025 11:52:37 +0100 Subject: [PATCH 2/2] + should still work --- modelskill/comparison/_collection.py | 27 ++++++++++++++++++--------- tests/test_combine_comparers.py | 3 ++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/modelskill/comparison/_collection.py b/modelskill/comparison/_collection.py index a64d9ae33..9728b8f6d 100644 --- a/modelskill/comparison/_collection.py +++ b/modelskill/comparison/_collection.py @@ -176,15 +176,23 @@ def __repr__(self) -> str: out.append(f"{index}: {key} - {value.quantity}") return str.join("\n", out) - def merge(self, other: "ComparerCollection") -> "ComparerCollection": + def merge(self, other: "ComparerCollection" | Comparer) -> "ComparerCollection": # make a copy of self to avoid modifying the original res = self.copy() - for cmp in other: - if cmp.name in self._comparers: - res._comparers[cmp.name] += cmp + if isinstance(other, Comparer): + if other.name in res._comparers: + res._comparers[other.name] += other else: - res._comparers[cmp.name] = cmp + res._comparers[other.name] = other + elif isinstance(other, ComparerCollection): + for cmp in other: + if cmp.name in self._comparers: + res._comparers[cmp.name] += cmp + else: + res._comparers[cmp.name] = cmp + else: + raise TypeError(f"Cannot merge {type(other)} with {type(self)}") return res def rename(self, mapping: Dict[str, str]) -> "ComparerCollection": @@ -259,10 +267,11 @@ def __add__( if not isinstance(other, (Comparer, ComparerCollection)): raise TypeError(f"Cannot add {type(other)} to {type(self)}") - if isinstance(other, Comparer): - return ComparerCollection([*self, other]) - elif isinstance(other, ComparerCollection): - return ComparerCollection([*self, *other]) + return self.merge(other) + # if isinstance(other, Comparer): + # return ComparerCollection([*self, other]) + # elif isinstance(other, ComparerCollection): + # return ComparerCollection([*self, *other]) def sel( self, diff --git a/tests/test_combine_comparers.py b/tests/test_combine_comparers.py index f180e7112..2b29ad79a 100644 --- a/tests/test_combine_comparers.py +++ b/tests/test_combine_comparers.py @@ -137,5 +137,6 @@ def test_concat_time_overlap(o123, mrmike): assert cc2.end_time == cc12.end_time assert cc2.n_points < cc12.n_points - cc12a = cc2.merge(cc26) + # + is supported but not recommended + cc12a = cc2 + cc26 assert cc12a.n_points == cc12.n_points