Skip to content

Commit 93a60cb

Browse files
committed
TST test printed message
1 parent 55f6e35 commit 93a60cb

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
steps:
2323
- uses: actions/checkout@v4
2424
- name: Build wheels
25-
uses: pypa/[email protected].2
25+
uses: pypa/[email protected].3
2626
env:
2727
CIBW_BUILD: cp3*-*
2828
CIBW_SKIP: pp* *i686* *musllinux* *-macosx_universal2 *-manylinux_ppc64le *-manylinux_s390x

tests/test_narx.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,6 @@ def test_narx(nan, multi_output):
178178
y_hat = narx_array_init_msa.predict(X, y_init=y_init)
179179
assert_array_equal(y_hat[:narx_array_init_msa.max_delay_], y_init)
180180

181-
print_narx(narx_array_init_msa)
182-
183181
with pytest.raises(ValueError, match=r"`coef_init` should have the shape of .*"):
184182
narx_array_init_msa.fit(X, y, coef_init=np.zeros(narx_osa_msa_coef.size))
185183

@@ -492,3 +490,43 @@ def test_tp2fd():
492490
poly_ids[-1][-1] = 5
493491
with pytest.raises(ValueError, match=r"The element x of poly_ids should.*"):
494492
_, _ = tp2fd(time_shift_ids, poly_ids)
493+
494+
def test_print_narx(capsys):
495+
X = np.random.rand(10, 2)
496+
y = np.random.rand(10, 2)
497+
feat_ids = np.array([[0, 1], [1, 2]])
498+
delay_ids = np.array([[1, 0], [2, 2]])
499+
500+
narx = NARX(
501+
feat_ids=feat_ids,
502+
delay_ids=delay_ids,
503+
output_ids=[0, 1],
504+
)
505+
narx.fit(X, y)
506+
print_narx(narx)
507+
captured = capsys.readouterr()
508+
# Check if the header is present in the output
509+
assert "| yid | Term | Coef |" in captured.out
510+
# Check if the intercept line for yid 0 is present
511+
assert "| 0 | Intercept |" in captured.out
512+
# Check if the intercept line for yid 1 is present
513+
assert "| 1 | Intercept |" in captured.out
514+
# Check if the term line for yid 0 is present
515+
assert "| 0 | X[k-1,0]*X[k,1] |" in captured.out
516+
# Check if the term line for yid 1 is present
517+
assert "| 1 |X[k-2,1]*y_hat[k-2,0]|" in captured.out
518+
519+
520+
def test_make_narx_refine_print(capsys):
521+
X = np.random.rand(10, 2)
522+
y = np.random.rand(10, 2)
523+
_ = make_narx(
524+
X,
525+
y,
526+
n_terms_to_select=2,
527+
max_delay=2,
528+
poly_degree=2,
529+
refine_drop=1,
530+
)
531+
captured = capsys.readouterr()
532+
assert "No. of iterations: " in captured.out

0 commit comments

Comments
 (0)