Skip to content

Commit 727a676

Browse files
committed
MNT output proper constant when some outputs only have intercepts
1 parent 820f4e6 commit 727a676

File tree

3 files changed

+124
-162
lines changed

3 files changed

+124
-162
lines changed

fastcan/narx.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,8 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
760760
warnings.warn(
761761
f"output_ids got {self.output_ids_}, which does not "
762762
f"contain all values from 0 to {self.n_outputs_ - 1}."
763-
"The predicted outputs for the missing values will be 0.",
763+
"The prediction for the missing outputs will be a constant"
764+
"(i.e., intercept).",
764765
UserWarning,
765766
)
766767

@@ -783,6 +784,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
783784
for i in range(self.n_outputs_):
784785
output_i_mask = self.output_ids_ == i
785786
if np.sum(output_i_mask) == 0:
787+
intercept[i] = np.mean(y_masked[:, i])
786788
continue
787789
osa_narx.fit(
788790
poly_terms_masked[:, output_i_mask],
@@ -974,8 +976,8 @@ def _update_dydx(
974976
dydx[k, y_ids, x_ids] = terms
975977

976978
# Update dynamic terms of Jacobian
977-
cfd = np.zeros((n_y, n_y, max_delay), dtype=float)
978-
if max_delay > 0:
979+
if max_delay > 0 and grad_yyd_ids.size > 0:
980+
cfd = np.zeros((n_y, n_y, max_delay), dtype=float)
979981
_update_cfd(
980982
X,
981983
y_hat,

0 commit comments

Comments
 (0)