Skip to content

Commit db2f7c2

Browse files
Merge pull request #447 from DHI/xy-in-comparer
Observation xyz is not included in Comparer for point-to-point comparisons
2 parents 11d50e3 + 8ae039e commit db2f7c2

File tree

2 files changed

+54
-21
lines changed

2 files changed

+54
-21
lines changed

modelskill/matching.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,38 @@
11
from __future__ import annotations
2+
3+
import warnings
24
from datetime import timedelta
35
from pathlib import Path
4-
import warnings
5-
66
from typing import (
7-
Iterable,
7+
Any,
88
Collection,
9+
Iterable,
910
List,
1011
Literal,
1112
Mapping,
1213
Optional,
13-
Union,
1414
Sequence,
15-
get_args,
1615
TypeVar,
17-
Any,
16+
Union,
17+
get_args,
1818
overload,
1919
)
20+
21+
import mikeio
2022
import numpy as np
2123
import pandas as pd
2224
import xarray as xr
2325

24-
import mikeio
25-
26-
27-
from . import model_result, Quantity
28-
from .timeseries import TimeSeries
29-
from .types import Period
26+
from . import Quantity, __version__, model_result
27+
from .comparison import Comparer, ComparerCollection
3028
from .model._base import Alignable
31-
from .model.grid import GridModelResult
3229
from .model.dfsu import DfsuModelResult
33-
from .model.track import TrackModelResult
3430
from .model.dummy import DummyModelResult
31+
from .model.grid import GridModelResult
32+
from .model.track import TrackModelResult
3533
from .obs import Observation, observation
36-
from .comparison import Comparer, ComparerCollection
37-
from . import __version__
34+
from .timeseries import TimeSeries
35+
from .types import Period
3836

3937
TimeDeltaTypes = Union[float, int, np.timedelta64, pd.Timedelta, timedelta]
4038
IdxOrNameTypes = Optional[Union[int, str]]
@@ -403,7 +401,8 @@ def match_space_time(
403401
)
404402
aligned = aligned.rename({v: f"{v}_mod" for v in overlapping_names})
405403

406-
data.update(aligned)
404+
for dv in aligned:
405+
data[dv] = aligned[dv]
407406

408407
# drop NaNs in model and observation columns (but allow NaNs in aux columns)
409408
def mo_kind(k: str) -> bool:

tests/test_match.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,32 @@ def mr3():
7070
return ms.model_result(fn, item=0, name="SW_3")
7171

7272

73+
def test_properties_after_match(o1, mr1):
74+
cmp = ms.match(o1, mr1)
75+
assert cmp.n_models == 1
76+
assert cmp.n_points == 386
77+
assert cmp.x == 4.242
78+
assert cmp.y == 52.6887
79+
assert cmp.z is None
80+
assert cmp.name == "HKNA"
81+
assert cmp.gtype == "point"
82+
assert cmp.mod_names == ["SW_1"]
83+
84+
85+
def test_properties_after_match_ts(o1):
86+
fn = "tests/testdata/SW/HKNA_Hm0.dfs0"
87+
mr = ms.PointModelResult(fn, item=0, name="SW_1")
88+
cmp = ms.match(o1, mr)
89+
assert cmp.n_models == 1
90+
assert cmp.n_points == 564
91+
assert cmp.x == 4.242
92+
assert cmp.y == 52.6887
93+
assert cmp.z is None
94+
assert cmp.name == "HKNA"
95+
assert cmp.gtype == "point"
96+
assert cmp.mod_names == ["SW_1"]
97+
98+
7399
# TODO remove in v1.1
74100
def test_compare_multi_obs_multi_model_is_deprecated(o1, o2, o3, mr1, mr2):
75101
with pytest.warns(FutureWarning, match="match"):
@@ -543,12 +569,20 @@ def test_compare_model_vs_dummy_for_track(mr1, o3):
543569
# better than dummy 🙂
544570
assert cmp2.score()["SW_1"] == pytest.approx(0.3524703)
545571

546-
def test_match_obs_model_pos_args_wrong_order_helpful_error_message():
547572

573+
def test_match_obs_model_pos_args_wrong_order_helpful_error_message():
548574
# match is pretty helpful in converting strings or dataset
549575
# so we need to use a ModelResult to trigger the error
550-
mr = ms.PointModelResult(data=pd.Series([0.0,0.0], index=pd.date_range("1970", periods=2, freq='d')), name="Zero")
551-
obs = ms.PointObservation(data=pd.Series([1.0, 2.0, 3.0], index=pd.date_range("1970", periods=3, freq='h')), name="MyStation")
576+
mr = ms.PointModelResult(
577+
data=pd.Series([0.0, 0.0], index=pd.date_range("1970", periods=2, freq="d")),
578+
name="Zero",
579+
)
580+
obs = ms.PointObservation(
581+
data=pd.Series(
582+
[1.0, 2.0, 3.0], index=pd.date_range("1970", periods=3, freq="h")
583+
),
584+
name="MyStation",
585+
)
552586

553587
with pytest.raises(TypeError, match="order"):
554-
ms.match(mr, obs)
588+
ms.match(mr, obs)

0 commit comments

Comments
 (0)