Skip to content

Commit a5c90bc

Browse files
Update cmethods.adjust's type annotations
1 parent 3addde7 commit a5c90bc

File tree

6 files changed

+39
-32
lines changed

6 files changed

+39
-32
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
repos:
77
- repo: https://github.com/astral-sh/ruff-pre-commit
8-
rev: v0.7.0
8+
rev: v0.9.3
99
hooks:
1010
- id: ruff
1111
args:
@@ -23,12 +23,12 @@ repos:
2323
# - --install-types
2424
# - --non-interactive
2525
- repo: https://github.com/codespell-project/codespell
26-
rev: v2.3.0
26+
rev: v2.4.0
2727
hooks:
2828
- id: codespell
2929
additional_dependencies: [tomli]
3030
- repo: https://github.com/pre-commit/pre-commit-hooks
31-
rev: v4.6.0
31+
rev: v5.0.0
3232
hooks:
3333
# all available hooks can be found here: https://github.com/pre-commit/pre-commit-hooks/blob/main/.pre-commit-hooks.yaml
3434
- id: check-yaml
@@ -72,7 +72,7 @@ repos:
7272
- id: isort
7373
args: [--profile=black]
7474
- repo: https://github.com/PyCQA/bandit
75-
rev: 1.7.10
75+
rev: 1.8.2
7676
hooks:
7777
- id: bandit
7878
exclude: "^tests/.*|examples/.*"

Makefile

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,16 @@ install:
3939
test:
4040
$(PYTHON) -m pytest $(PYTEST_OPTS) $(TESTS)
4141

42-
.PHONY: tests
42+
.PHONY: test
4343
tests: test
4444

45+
## retest Rerun tests that failed before
46+
##
47+
.PHONY: retest
48+
retest:
49+
$(PYTHON) -m pytest $(PYTEST_OPTS) --lf $(TESTS)
50+
51+
4552
## wip Run tests marked as wip
4653
##
4754
.PHONY: wip

cmethods/core.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from cmethods.scaling import linear_scaling as __linear_scaling
2222
from cmethods.scaling import variance_scaling as __variance_scaling
2323
from cmethods.static import SCALING_METHODS
24-
from cmethods.utils import UnknownMethodError, check_xr_types
24+
from cmethods.utils import UnknownMethodError, ensure_xr_dataarray
2525

2626
if TYPE_CHECKING:
2727
from cmethods.types import XRData
@@ -37,16 +37,16 @@
3737

3838
def apply_ufunc(
3939
method: str,
40-
obs: XRData,
41-
simh: XRData,
42-
simp: XRData,
40+
obs: xr.xarray.core.dataarray.DataArray,
41+
simh: xr.xarray.core.dataarray.DataArray,
42+
simp: xr.xarray.core.dataarray.DataArray,
4343
**kwargs: dict,
44-
) -> XRData:
44+
) -> xr.xarray.core.dataarray.DataArray:
4545
"""
4646
Internal function used to apply the bias correction technique to the
4747
passed input data.
4848
"""
49-
check_xr_types(obs=obs, simh=simh, simp=simp)
49+
ensure_xr_dataarray(obs=obs, simh=simh, simp=simp)
5050
if method not in __METHODS_FUNC__:
5151
raise UnknownMethodError(method, __METHODS_FUNC__.keys())
5252

@@ -96,11 +96,11 @@ def apply_ufunc(
9696

9797
def adjust(
9898
method: str,
99-
obs: XRData,
100-
simh: XRData,
101-
simp: XRData,
99+
obs: xr.xarray.core.dataarray.DataArray,
100+
simh: xr.xarray.core.dataarray.DataArray,
101+
simp: xr.xarray.core.dataarray.DataArray,
102102
**kwargs,
103-
) -> XRData:
103+
) -> xr.xarray.core.dataarray.DataArray | xr.xarray.core.dataarray.Dataset:
104104
"""
105105
Function to apply a bias correction technique on single and multidimensional
106106
data sets. For more information please refer to the method specific
@@ -119,19 +119,19 @@ def adjust(
119119
:param method: Technique to apply
120120
:type method: str
121121
:param obs: The reference/observational data set
122-
:type obs: XRData
122+
:type obs: xr.xarray.core.dataarray.DataArray
123123
:param simh: The modeled data of the control period
124-
:type simh: XRData
124+
:type simh: xr.xarray.core.dataarray.DataArray
125125
:param simp: The modeled data of the period to adjust
126-
:type simp: XRData
126+
:type simp: xr.xarray.core.dataarray.DataArray
127127
:param kwargs: Any other method-specific parameter (like
128128
``n_quantiles`` and ``kind``)
129129
:type kwargs: dict
130130
:return: The bias corrected/adjusted data set
131-
:rtype: XRData
131+
:rtype: xr.xarray.core.dataarray.DataArray | xr.xarray.core.dataarray.Dataset
132132
"""
133133
kwargs["adjust_called"] = True
134-
check_xr_types(obs=obs, simh=simh, simp=simp)
134+
ensure_xr_dataarray(obs=obs, simh=simh, simp=simp)
135135

136136
if method == "detrended_quantile_mapping": # noqa: PLR2004
137137
raise ValueError(

cmethods/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ def check_adjust_called(
5151
)
5252

5353

54-
def check_xr_types(obs: XRData, simh: XRData, simp: XRData) -> None:
54+
def ensure_xr_dataarray(obs: XRData, simh: XRData, simp: XRData) -> None:
5555
"""
5656
Checks if the parameters are in the correct type. **only used internally**
5757
"""
58-
phrase: str = "must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray"
58+
phrase: str = "must be type 'xarray.core.dataarray.DataArray'."
5959

6060
if not isinstance(obs, XRData_t):
6161
raise TypeError(f"'obs' {phrase}")
@@ -73,7 +73,7 @@ def check_np_types(
7373
"""
7474
Checks if the parameters are in the correct type. **only used internally**
7575
"""
76-
phrase: str = "must be type list, np.ndarray or np.generic"
76+
phrase: str = "must be type list, np.ndarray, or np.generic"
7777

7878
if not isinstance(obs, NPData_t):
7979
raise TypeError(f"'obs' {phrase}")
@@ -246,8 +246,8 @@ def get_adjusted_scaling_factor(
246246
"UnknownMethodError",
247247
"check_adjust_called",
248248
"check_np_types",
249-
"check_xr_types",
250249
"ensure_dividable",
250+
"ensure_xr_dataarray",
251251
"get_adjusted_scaling_factor",
252252
"get_cdf",
253253
"get_inverse_of_cdf",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ test = [
112112
"scikit-learn",
113113
"scipy",
114114
]
115-
examples = ["click", "matplotlib"]
115+
examples = ["matplotlib"]
116116

117117
[tool.codespell]
118118
check-filenames = true

tests/test_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from cmethods.static import MAX_SCALING_FACTOR
2626
from cmethods.utils import (
2727
check_np_types,
28-
check_xr_types,
2928
ensure_dividable,
29+
ensure_xr_dataarray,
3030
get_adjusted_scaling_factor,
3131
get_pdf,
3232
nan_or_equal,
@@ -133,7 +133,7 @@ def test_xr_type_check() -> None:
133133
correct. No error should occur.
134134
"""
135135
ds: xr.core.dataarray.Dataset = xr.core.dataarray.Dataset()
136-
check_xr_types(obs=ds, simh=ds, simp=ds)
136+
ensure_xr_dataarray(obs=ds, simh=ds, simp=ds)
137137

138138

139139
def test_type_check_failing() -> None:
@@ -142,7 +142,7 @@ def test_type_check_failing() -> None:
142142
have the correct type.
143143
"""
144144

145-
phrase: str = "must be type list, np.ndarray or np.generic"
145+
phrase: str = "must be type list, np.ndarray, or np.generic"
146146
with pytest.raises(TypeError, match=f"'obs' {phrase}"):
147147
check_np_types(obs=1, simh=[], simp=[])
148148

@@ -177,7 +177,7 @@ def test_detrended_quantile_mapping_type_check_simp_failing(datasets: dict) -> N
177177
"""n_quantiles must by type int"""
178178
with pytest.raises(
179179
TypeError,
180-
match="'simp' must be type xarray.core.dataarray.DataArray",
180+
match=r"'simp' must be type xarray.core.dataarray.DataArray",
181181
):
182182
detrended_quantile_mapping( # type: ignore[attr-defined]
183183
obs=datasets["+"]["obsh"][:, 0, 0],
@@ -222,7 +222,7 @@ def test_adjust_type_checking_failing() -> None:
222222
)
223223
with pytest.raises(
224224
TypeError,
225-
match="'obs' must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray",
225+
match=r"'obs' must be type 'xarray.core.dataarray.DataArray'.",
226226
):
227227
adjust(
228228
method="linear_scaling",
@@ -233,7 +233,7 @@ def test_adjust_type_checking_failing() -> None:
233233
)
234234
with pytest.raises(
235235
TypeError,
236-
match="'simh' must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray",
236+
match=r"'simh' must be type 'xarray.core.dataarray.DataArray'.",
237237
):
238238
adjust(
239239
method="linear_scaling",
@@ -245,7 +245,7 @@ def test_adjust_type_checking_failing() -> None:
245245

246246
with pytest.raises(
247247
TypeError,
248-
match="'simp' must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray",
248+
match=r"'simp' must be type 'xarray.core.dataarray.DataArray'.",
249249
):
250250
adjust(
251251
method="linear_scaling",

0 commit comments

Comments
 (0)