Skip to content

Commit 6efac05

Browse files
Merge pull request #66 from MatthewSZhang/narx-jac
FEAT add jac for narx
2 parents 357b22f + 183d8cc commit 6efac05

File tree

10 files changed

+2243
-1370
lines changed

10 files changed

+2243
-1370
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
steps:
2323
- uses: actions/checkout@v4
2424
- name: Build wheels
25-
uses: pypa/[email protected].1
25+
uses: pypa/[email protected].2
2626
env:
2727
CIBW_BUILD: cp3*-*
2828
CIBW_SKIP: pp* *i686* *musllinux* *-macosx_universal2 *-manylinux_ppc64le *-manylinux_s390x

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ jobs:
99

1010
steps:
1111
- uses: actions/checkout@v4
12-
- uses: prefix-dev/[email protected].3
12+
- uses: prefix-dev/[email protected].4
1313
with:
1414
environments: default
1515
cache: true
1616

1717
- name: Re-install local
1818
run: |
19-
pixi run rebuild
19+
pixi reinstall --frozen fastcan
2020
2121
- name: Lint with ruff
2222
run: |

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828

2929
- name: Re-install local
3030
run: |
31-
pixi run rebuild
31+
pixi reinstall --frozen fastcan
3232
3333
- name: Test with pytest
3434
run: |

examples/plot_narx.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,14 @@
125125
# In the printed NARX model, it is found that :class:`FastCan` selects the correct
126126
# terms and the coefficients are close to the true values.
127127

128-
from fastcan.narx import NARX, print_narx
128+
from fastcan.narx import NARX, print_narx, tp2fd
129+
130+
# Convert poly_ids and time_shift_ids to feat_ids and delay_ids
131+
feat_ids, delay_ids = tp2fd(time_shift_ids, selected_poly_ids)
129132

130133
narx_model = NARX(
131-
time_shift_ids=time_shift_ids,
132-
poly_ids=selected_poly_ids,
134+
feat_ids=feat_ids,
135+
delay_ids=delay_ids,
133136
)
134137

135138
narx_model.fit(X, y)

examples/plot_narx_msa.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Nonlinear system
1616
# ----------------
1717
#
18-
# `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>` is used to
18+
# `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>`_ is used to
1919
# generate simulated data. The mathematical model is given by
2020
#
2121
# .. math::
@@ -82,15 +82,18 @@ def auto_duffing_equation(y, t):
8282
dur = 10
8383
n_samples = 1000
8484

85+
rng = np.random.default_rng(12345)
86+
e_train = rng.normal(0, 0.0002, n_samples)
87+
e_test = rng.normal(0, 0.0002, n_samples)
8588
t = np.linspace(0, dur, n_samples)
8689

8790
sol = odeint(duffing_equation, [0.6, 0.8], t)
8891
u_train = 2.5 * np.cos(2 * np.pi * t).reshape(-1, 1)
89-
y_train = sol[:, 0]
92+
y_train = sol[:, 0] + e_train
9093

91-
sol = odeint(auto_duffing_equation, [0.6, -0.8], t)
94+
sol = odeint(duffing_equation, [0.6, -0.8], t)
9295
u_test = 2.5 * np.cos(2 * np.pi * t).reshape(-1, 1)
93-
y_test = sol[:, 0]
96+
y_test = sol[:, 0]+ e_test
9497

9598
# %%
9699
# One-step-head VS. multi-step-ahead NARX
@@ -105,12 +108,12 @@ def auto_duffing_equation(y, t):
105108

106109
from fastcan.narx import make_narx
107110

108-
max_delay = 2
111+
max_delay = 3
109112

110113
narx_model = make_narx(
111114
X=u_train,
112115
y=y_train,
113-
n_terms_to_select=10,
116+
n_terms_to_select=5,
114117
max_delay=max_delay,
115118
poly_degree=3,
116119
verbose=0,
@@ -130,7 +133,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
130133
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
131134
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
132135

133-
narx_model.fit(u_train, y_train, coef_init="one_step_ahead", method="Nelder-Mead")
136+
narx_model.fit(u_train, y_train, coef_init="one_step_ahead")
134137
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
135138
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
136139

@@ -159,7 +162,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
159162
narx_model = make_narx(
160163
X=u_all,
161164
y=y_all,
162-
n_terms_to_select=10,
165+
n_terms_to_select=5,
163166
max_delay=max_delay,
164167
poly_degree=3,
165168
verbose=0,
@@ -169,7 +172,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
169172
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
170173
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
171174

172-
narx_model.fit(u_all, y_all, coef_init="one_step_ahead", method="Nelder-Mead")
175+
narx_model.fit(u_all, y_all, coef_init="one_step_ahead")
173176
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
174177
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
175178

0 commit comments

Comments
 (0)