|
27 | 27 | from ._fastcan import FastCan
|
28 | 28 | from ._narx_fast import _predict_step, _update_cfd, _update_terms # type: ignore
|
29 | 29 | from ._refine import refine
|
| 30 | +from .utils import mask_missing_values |
30 | 31 |
|
31 | 32 |
|
32 | 33 | @validate_params(
|
@@ -273,14 +274,6 @@ def make_poly_ids(
|
273 | 274 | return np.delete(ids, const_id, 0) # remove the constant feature
|
274 | 275 |
|
275 | 276 |
|
276 |
| -def _mask_missing_value(*arr, return_mask=False): |
277 |
| - """Remove missing value for all arrays.""" |
278 |
| - mask_nomissing = np.all(np.isfinite(np.c_[arr]), axis=1) |
279 |
| - if return_mask: |
280 |
| - return mask_nomissing |
281 |
| - return tuple([x[mask_nomissing] for x in arr]) |
282 |
| - |
283 |
| - |
284 | 277 | def _valiate_time_shift_poly_ids(
|
285 | 278 | time_shift_ids, poly_ids, n_samples=None, n_features=None, n_outputs=None
|
286 | 279 | ):
|
@@ -374,7 +367,7 @@ def _validate_feat_delay_ids(
|
374 | 367 | )
|
375 | 368 | if (delay_ids_.min() < -1) or (delay_ids_.max() >= n_samples):
|
376 | 369 | raise ValueError(
|
377 |
| - "The element x of delay_ids should " f"satisfy -1 <= x < {n_samples}." |
| 370 | + f"The element x of delay_ids should satisfy -1 <= x < {n_samples}." |
378 | 371 | )
|
379 | 372 | return feat_ids_, delay_ids_
|
380 | 373 |
|
@@ -783,7 +776,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
|
783 | 776 | time_shift_vars = make_time_shift_features(xy_hstack, time_shift_ids)
|
784 | 777 | poly_terms = make_poly_features(time_shift_vars, poly_ids)
|
785 | 778 | # Remove missing values
|
786 |
| - poly_terms_masked, y_masked, sample_weight_masked = _mask_missing_value( |
| 779 | + poly_terms_masked, y_masked, sample_weight_masked = mask_missing_values( |
787 | 780 | poly_terms, y, sample_weight
|
788 | 781 | )
|
789 | 782 | coef = np.zeros(n_terms, dtype=float)
|
@@ -1060,7 +1053,7 @@ def _loss(
|
1060 | 1053 | output_ids,
|
1061 | 1054 | )
|
1062 | 1055 |
|
1063 |
| - y_masked, y_hat_masked, sample_weight_sqrt_masked = _mask_missing_value( |
| 1056 | + y_masked, y_hat_masked, sample_weight_sqrt_masked = mask_missing_values( |
1064 | 1057 | y, y_hat, sample_weight_sqrt
|
1065 | 1058 | )
|
1066 | 1059 |
|
@@ -1115,12 +1108,10 @@ def _grad(
|
1115 | 1108 | grad_delay_ids,
|
1116 | 1109 | )
|
1117 | 1110 |
|
1118 |
| - mask_nomissing = _mask_missing_value( |
1119 |
| - y, y_hat, sample_weight_sqrt, return_mask=True |
1120 |
| - ) |
| 1111 | + mask_valid = mask_missing_values(y, y_hat, sample_weight_sqrt, return_mask=True) |
1121 | 1112 |
|
1122 |
| - sample_weight_sqrt_masked = sample_weight_sqrt[mask_nomissing] |
1123 |
| - dydx_masked = dydx[mask_nomissing] |
| 1113 | + sample_weight_sqrt_masked = sample_weight_sqrt[mask_valid] |
| 1114 | + dydx_masked = dydx[mask_valid] |
1124 | 1115 |
|
1125 | 1116 | return dydx_masked.sum(axis=1) * sample_weight_sqrt_masked
|
1126 | 1117 |
|
@@ -1264,7 +1255,7 @@ def _get_term_str(term_feat_ids, term_delay_ids):
|
1264 | 1255 | else:
|
1265 | 1256 | term_str += f"*X[k-{delay_id},{feat_id}]"
|
1266 | 1257 | elif feat_id >= narx.n_features_in_:
|
1267 |
| - term_str += f"*y_hat[k-{delay_id},{feat_id-narx.n_features_in_}]" |
| 1258 | + term_str += f"*y_hat[k-{delay_id},{feat_id - narx.n_features_in_}]" |
1268 | 1259 | return term_str[1:]
|
1269 | 1260 |
|
1270 | 1261 | yid_space = 5
|
@@ -1472,7 +1463,7 @@ def make_narx(
|
1472 | 1463 | poly_terms = make_poly_features(time_shift_vars, poly_ids_all)
|
1473 | 1464 |
|
1474 | 1465 | # Remove missing values
|
1475 |
| - poly_terms_masked, y_masked = _mask_missing_value(poly_terms, y) |
| 1466 | + poly_terms_masked, y_masked = mask_missing_values(poly_terms, y) |
1476 | 1467 |
|
1477 | 1468 | selected_poly_ids = []
|
1478 | 1469 | for i in range(n_outputs):
|
|
0 commit comments