Skip to content

Commit 4006c59

Browse files
committed
FEAT add mask_missing_values in utils
1 parent c939d8c commit 4006c59

File tree

5 files changed

+426
-400
lines changed

5 files changed

+426
-400
lines changed

fastcan/narx.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ._fastcan import FastCan
2828
from ._narx_fast import _predict_step, _update_cfd, _update_terms # type: ignore
2929
from ._refine import refine
30+
from .utils import mask_missing_values
3031

3132

3233
@validate_params(
@@ -273,14 +274,6 @@ def make_poly_ids(
273274
return np.delete(ids, const_id, 0) # remove the constant feature
274275

275276

276-
def _mask_missing_value(*arr, return_mask=False):
277-
"""Remove missing value for all arrays."""
278-
mask_nomissing = np.all(np.isfinite(np.c_[arr]), axis=1)
279-
if return_mask:
280-
return mask_nomissing
281-
return tuple([x[mask_nomissing] for x in arr])
282-
283-
284277
def _valiate_time_shift_poly_ids(
285278
time_shift_ids, poly_ids, n_samples=None, n_features=None, n_outputs=None
286279
):
@@ -374,7 +367,7 @@ def _validate_feat_delay_ids(
374367
)
375368
if (delay_ids_.min() < -1) or (delay_ids_.max() >= n_samples):
376369
raise ValueError(
377-
"The element x of delay_ids should " f"satisfy -1 <= x < {n_samples}."
370+
f"The element x of delay_ids should satisfy -1 <= x < {n_samples}."
378371
)
379372
return feat_ids_, delay_ids_
380373

@@ -783,7 +776,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
783776
time_shift_vars = make_time_shift_features(xy_hstack, time_shift_ids)
784777
poly_terms = make_poly_features(time_shift_vars, poly_ids)
785778
# Remove missing values
786-
poly_terms_masked, y_masked, sample_weight_masked = _mask_missing_value(
779+
poly_terms_masked, y_masked, sample_weight_masked = mask_missing_values(
787780
poly_terms, y, sample_weight
788781
)
789782
coef = np.zeros(n_terms, dtype=float)
@@ -1060,7 +1053,7 @@ def _loss(
10601053
output_ids,
10611054
)
10621055

1063-
y_masked, y_hat_masked, sample_weight_sqrt_masked = _mask_missing_value(
1056+
y_masked, y_hat_masked, sample_weight_sqrt_masked = mask_missing_values(
10641057
y, y_hat, sample_weight_sqrt
10651058
)
10661059

@@ -1115,12 +1108,10 @@ def _grad(
11151108
grad_delay_ids,
11161109
)
11171110

1118-
mask_nomissing = _mask_missing_value(
1119-
y, y_hat, sample_weight_sqrt, return_mask=True
1120-
)
1111+
mask_valid = mask_missing_values(y, y_hat, sample_weight_sqrt, return_mask=True)
11211112

1122-
sample_weight_sqrt_masked = sample_weight_sqrt[mask_nomissing]
1123-
dydx_masked = dydx[mask_nomissing]
1113+
sample_weight_sqrt_masked = sample_weight_sqrt[mask_valid]
1114+
dydx_masked = dydx[mask_valid]
11241115

11251116
return dydx_masked.sum(axis=1) * sample_weight_sqrt_masked
11261117

@@ -1264,7 +1255,7 @@ def _get_term_str(term_feat_ids, term_delay_ids):
12641255
else:
12651256
term_str += f"*X[k-{delay_id},{feat_id}]"
12661257
elif feat_id >= narx.n_features_in_:
1267-
term_str += f"*y_hat[k-{delay_id},{feat_id-narx.n_features_in_}]"
1258+
term_str += f"*y_hat[k-{delay_id},{feat_id - narx.n_features_in_}]"
12681259
return term_str[1:]
12691260

12701261
yid_space = 5
@@ -1472,7 +1463,7 @@ def make_narx(
14721463
poly_terms = make_poly_features(time_shift_vars, poly_ids_all)
14731464

14741465
# Remove missing values
1475-
poly_terms_masked, y_masked = _mask_missing_value(poly_terms, y)
1466+
poly_terms_masked, y_masked = mask_missing_values(poly_terms, y)
14761467

14771468
selected_poly_ids = []
14781469
for i in range(n_outputs):

fastcan/utils.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
from sklearn.cross_decomposition import CCA
10-
from sklearn.utils import check_X_y
10+
from sklearn.utils import _safe_indexing, check_consistent_length, check_X_y
1111
from sklearn.utils._param_validation import Interval, validate_params
1212

1313

@@ -120,3 +120,52 @@ def ols(X, y, t=1):
120120
if not mask[j]:
121121
w[:, j] = w[:, j] - w[:, d] * (w[:, d] @ w[:, j])
122122
w[:, j] /= np.linalg.norm(w[:, j], axis=0)
123+
124+
125+
@validate_params(
126+
{
127+
"return_mask": ["boolean"],
128+
},
129+
prefer_skip_nested_validation=True,
130+
)
131+
def mask_missing_values(*arrays, return_mask=False):
132+
"""Remove missing values for all arrays.
133+
134+
Parameters
135+
----------
136+
*arrays : sequence of array-like of shape (n_samples,) or \
137+
(n_samples, n_outputs)
138+
Arrays with consistent first dimension.
139+
140+
return_mask : bool, default=False
141+
If True, return a mask of valid values.
142+
If False, return the arrays with missing values removed.
143+
144+
Returns
145+
-------
146+
mask_valid : ndarray of shape (n_samples,)
147+
Mask of valid values. Only returned if `return_mask` is True.
148+
149+
masked_arrays : sequence of array-like of shape (n_samples,) or \
150+
(n_samples, n_outputs)
151+
Arrays with missing values removed. Only returned if `return_mask` is False.
152+
The order of the arrays is the same as the input arrays.
153+
154+
Examples
155+
--------
156+
>>> import numpy as np
157+
>>> from fastcan.utils import mask_missing_values
158+
>>> a = [[1, 2], [3, np.nan], [5, 6]]
159+
>>> b = [1, 2, 3]
160+
>>> mask_missing_values(a, b)
161+
[[[1, 2], [5, 6]], [1, 3]]
162+
>>> mask_missing_values(a, b, return_mask=True)
163+
array([ True, False, True])
164+
"""
165+
if len(arrays) == 0:
166+
return None
167+
check_consistent_length(*arrays)
168+
mask_valid = np.all(np.isfinite(np.c_[arrays]), axis=1)
169+
if return_mask:
170+
return mask_valid
171+
return [_safe_indexing(x, mask_valid) for x in arrays]

0 commit comments

Comments
 (0)