Skip to content

Commit c9b87b1

Browse files
address review comments
1 parent 46db09b commit c9b87b1

File tree

4 files changed

+23
-65
lines changed

4 files changed

+23
-65
lines changed

.github/workflows/actions.yml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
name: Tests
22

3+
# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future
4+
# Currently only basic flow tests run with NNX enabled
5+
36
on:
47
push:
58
branches: [ master ]
@@ -17,11 +20,18 @@ jobs:
1720
matrix:
1821
python-version: ['3.10']
1922
backend: [tensorflow, jax, torch, numpy, openvino]
23+
nnx_enabled: [false]
24+
include:
25+
- python-version: '3.10'
26+
backend: jax
27+
nnx_enabled: true
2028
name: Run tests
2129
runs-on: ubuntu-latest
2230
env:
2331
PYTHON: ${{ matrix.python-version }}
2432
KERAS_HOME: .github/workflows/config/${{ matrix.backend }}
33+
KERAS_BACKEND: ${{ matrix.backend }}
34+
KERAS_NNX_ENABLED: ${{ matrix.nnx_enabled }}
2535
steps:
2636
- uses: actions/checkout@v4
2737
- name: Check for changes in keras/src/applications
@@ -48,6 +58,9 @@ jobs:
4858
- name: Install dependencies
4959
run: |
5060
pip install -r requirements.txt --progress-bar off --upgrade
61+
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
62+
pip install --upgrade git+https://github.com/divyashreepathihalli/flax.git@use-id
63+
fi
5164
pip uninstall -y keras keras-nightly
5265
pip install -e "." --progress-bar off --upgrade
5366
- name: Test applications with pytest
@@ -73,6 +86,11 @@ jobs:
7386
if: ${{ matrix.backend == 'jax'}}
7487
run: |
7588
python integration_tests/jax_custom_fit_test.py
89+
- name: Test basic flow with NNX
90+
if: ${{ matrix.nnx_enabled == 'true'}}
91+
run: |
92+
python integration_tests/import_test.py
93+
python integration_tests/basic_full_flow.py
7694
- name: Test TF-specific integrations
7795
if: ${{ matrix.backend == 'tensorflow'}}
7896
run: |
@@ -96,8 +114,8 @@ jobs:
96114
- name: Codecov keras
97115
uses: codecov/codecov-action@v5
98116
with:
99-
env_vars: PYTHON,KERAS_HOME
100-
flags: keras,keras-${{ matrix.backend }}
117+
env_vars: PYTHON,KERAS_HOME,KERAS_BACKEND,KERAS_NNX_ENABLED
118+
flags: keras,keras-${{ matrix.backend }}${{ matrix.nnx_enabled == 'true' && '-nnx' || '' }}
101119
files: core-coverage.xml
102120
token: ${{ secrets.CODECOV_TOKEN }}
103121
fail_ci_if_error: false

.github/workflows/config/jax/keras.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
"floatx": "float32",
33
"epsilon": 1e-07,
44
"backend": "jax",
5-
"image_data_format": "channels_last"
5+
"image_data_format": "channels_last",
6+
"nnx_enabled": false
67
}

.github/workflows/nnx-tests.yml

Lines changed: 0 additions & 61 deletions
This file was deleted.

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin'
1414
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
1515
# Note that we test against the latest JAX on GPU.
1616
jax[cpu]==0.5.0
17-
flax>=0.10.1
17+
flax
1818
# Common deps.
1919
-r requirements-common.txt

0 commit comments

Comments
 (0)