Skip to content

FEAT add jac for narx #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Build wheels
uses: pypa/[email protected].1
uses: pypa/[email protected].2
env:
CIBW_BUILD: cp3*-*
CIBW_SKIP: pp* *i686* *musllinux* *-macosx_universal2 *-manylinux_ppc64le *-manylinux_s390x
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ jobs:

steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected].3
- uses: prefix-dev/[email protected].4
with:
environments: default
cache: true

- name: Re-install local
run: |
pixi run rebuild
pixi reinstall --frozen fastcan

- name: Lint with ruff
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:

- name: Re-install local
run: |
pixi run rebuild
pixi reinstall --frozen fastcan

- name: Test with pytest
run: |
Expand Down
9 changes: 6 additions & 3 deletions examples/plot_narx.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,14 @@
# In the printed NARX model, it is found that :class:`FastCan` selects the correct
# terms and the coefficients are close to the true values.

from fastcan.narx import NARX, print_narx
from fastcan.narx import NARX, print_narx, tp2fd

# Convert poly_ids and time_shift_ids to feat_ids and delay_ids
feat_ids, delay_ids = tp2fd(time_shift_ids, selected_poly_ids)

narx_model = NARX(
time_shift_ids=time_shift_ids,
poly_ids=selected_poly_ids,
feat_ids=feat_ids,
delay_ids=delay_ids,
)

narx_model.fit(X, y)
Expand Down
21 changes: 12 additions & 9 deletions examples/plot_narx_msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Nonlinear system
# ----------------
#
# `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>` is used to
# `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>`_ is used to
# generate simulated data. The mathematical model is given by
#
# .. math::
Expand Down Expand Up @@ -82,15 +82,18 @@ def auto_duffing_equation(y, t):
dur = 10
n_samples = 1000

rng = np.random.default_rng(12345)
e_train = rng.normal(0, 0.0002, n_samples)
e_test = rng.normal(0, 0.0002, n_samples)
t = np.linspace(0, dur, n_samples)

sol = odeint(duffing_equation, [0.6, 0.8], t)
u_train = 2.5 * np.cos(2 * np.pi * t).reshape(-1, 1)
y_train = sol[:, 0]
y_train = sol[:, 0] + e_train

sol = odeint(auto_duffing_equation, [0.6, -0.8], t)
sol = odeint(duffing_equation, [0.6, -0.8], t)
u_test = 2.5 * np.cos(2 * np.pi * t).reshape(-1, 1)
y_test = sol[:, 0]
y_test = sol[:, 0]+ e_test

# %%
# One-step-head VS. multi-step-ahead NARX
Expand All @@ -105,12 +108,12 @@ def auto_duffing_equation(y, t):

from fastcan.narx import make_narx

max_delay = 2
max_delay = 3

narx_model = make_narx(
X=u_train,
y=y_train,
n_terms_to_select=10,
n_terms_to_select=5,
max_delay=max_delay,
poly_degree=3,
verbose=0,
Expand All @@ -130,7 +133,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

narx_model.fit(u_train, y_train, coef_init="one_step_ahead", method="Nelder-Mead")
narx_model.fit(u_train, y_train, coef_init="one_step_ahead")
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

Expand Down Expand Up @@ -159,7 +162,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
narx_model = make_narx(
X=u_all,
y=y_all,
n_terms_to_select=10,
n_terms_to_select=5,
max_delay=max_delay,
poly_degree=3,
verbose=0,
Expand All @@ -169,7 +172,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

narx_model.fit(u_all, y_all, coef_init="one_step_ahead", method="Nelder-Mead")
narx_model.fit(u_all, y_all, coef_init="one_step_ahead")
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

Expand Down
Loading
Loading