Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Support python 3.12 and misc #205

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
dfd7c8f
Update setup.py
romanngg Mar 8, 2024
c460bed
Update setup.py
romanngg Mar 8, 2024
5434520
Update linux.yml
romanngg Mar 8, 2024
93100cc
Update macos.yml
romanngg Mar 8, 2024
ae1bede
Update setup.py
romanngg Mar 8, 2024
bb18562
Update setup.py
romanngg Mar 14, 2024
266b8ec
Update setup.py
romanngg Mar 14, 2024
821960f
Update setup.py
romanngg Mar 14, 2024
e824bf9
Update setup.py
romanngg Mar 14, 2024
a4130ac
Update setup.py
romanngg Mar 15, 2024
433f82e
Update empirical.py
romanngg Mar 15, 2024
aff65c1
Update empirical_tf_test.py
romanngg Mar 15, 2024
f928978
Update setup.py
romanngg Mar 15, 2024
fb29308
Update setup.py
romanngg Mar 19, 2024
a6b2cd9
Update empirical_tf_test.py
romanngg Mar 19, 2024
387d32f
Update empirical_tf_test.py
romanngg Mar 28, 2024
f31a2d9
Update empirical_ntk_tf_test.py
romanngg Mar 28, 2024
3cc684d
Update empirical_ntk_tf_test.py
romanngg Mar 28, 2024
2256a35
Update setup.py
romanngg Mar 28, 2024
83bf6f0
Update setup.py
romanngg Mar 28, 2024
04b03b5
Update linux.yml
romanngg Mar 28, 2024
8028008
Update macos.yml
romanngg Mar 28, 2024
46f03fa
Update setup.py
romanngg Jul 19, 2024
73fd32a
Update linux.yml
romanngg Jul 19, 2024
dd709d8
Update macos.yml
romanngg Jul 19, 2024
e4901cf
Update pytype.yml
romanngg Jul 19, 2024
abee391
Update sketching.yml
romanngg Jul 19, 2024
63e257c
Update predict.py
romanngg Sep 2, 2024
c365aeb
Update setup.py
romanngg Sep 2, 2024
d4b1df2
Update setup.py
romanngg Sep 2, 2024
ab78a32
Update predict.py
romanngg Sep 2, 2024
449f752
Update predict.py
romanngg Sep 2, 2024
9443dcb
Misc
romanngg Sep 3, 2024
fe7a397
Try a stable release of `tf2jax`
romanngg Sep 3, 2024
6f4bc34
Require `tensorflow>=2.16.2` for macos
romanngg Sep 3, 2024
4fec634
Update jax
romanngg Sep 12, 2024
11df20e
Merge remote-tracking branch 'origin/main'
romanngg Sep 12, 2024
a6a8cdf
No-op
romanngg Sep 12, 2024
9d8361a
No-op
romanngg Sep 12, 2024
41aeb11
Merge remote-tracking branch 'origin/main'
romanngg Sep 12, 2024
ecf1de0
Update jax
romanngg Sep 12, 2024
6476351
Merge remote-tracking branch 'origin/main'
romanngg Sep 12, 2024
b2b4005
Update macos runner
romanngg Sep 12, 2024
ee923b9
Undo macos runner
romanngg Sep 12, 2024
471e3a9
Merge remote-tracking branch 'origin/main'
romanngg Sep 12, 2024
ac477d8
Disable TF tests on macos
romanngg Sep 12, 2024
88dfe39
Update versions
romanngg Oct 7, 2024
a017fab
Update versions
romanngg Oct 7, 2024
efa98dc
Update versions
romanngg Oct 24, 2024
2da21b1
Update versions
romanngg Oct 24, 2024
f74fe40
Update versions
romanngg Oct 24, 2024
e2e3a25
Update versions
romanngg Oct 24, 2024
20572f2
Update versions
romanngg Oct 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:

strategy:
matrix:
python-version: [3.9, '3.10', 3.11]
python-version: ['3.10', 3.11, 3.12]
JAX_ENABLE_X64: [0, 1]

runs-on: ubuntu-latest
Expand All @@ -32,10 +32,10 @@ jobs:
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4.1.1
- uses: actions/checkout@v4.2.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5.0.0
uses: actions/setup-python@v5.2.0
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -54,7 +54,7 @@ jobs:
JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest -n auto --cov=neural_tangents --cov-report=xml --cov-report=term

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4.0.1
uses: codecov/codecov-action@v4.6.0
with:
file: ./coverage.xml

Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:

strategy:
matrix:
python-version: [3.9, '3.10', 3.11]
python-version: ['3.10', 3.11, 3.12]
JAX_ENABLE_X64: [0]

runs-on: macos-latest
Expand All @@ -32,10 +32,10 @@ jobs:
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4.1.1
- uses: actions/checkout@v4.2.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5.0.0
uses: actions/setup-python@v5.2.0
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -54,7 +54,7 @@ jobs:
JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest -n auto --cov=neural_tangents --cov-report=xml --cov-report=term

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4.0.1
uses: codecov/codecov-action@v4.6.0
with:
file: ./coverage.xml

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/pytype.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ jobs:
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4.1.1
- uses: actions/checkout@v4.2.1

- name: Set up Python 3.10
uses: actions/setup-python@v5.0.0
uses: actions/setup-python@v5.2.0
with:
python-version: '3.10'

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/sketching.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ jobs:
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4.1.1
- uses: actions/checkout@v4.2.1

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5.0.0
uses: actions/setup-python@v5.2.0
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -53,7 +53,7 @@ jobs:
JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest experimental/tests/ -n auto --cov=experimental/ --cov-report=xml --cov-report=term

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4.0.1
uses: codecov/codecov-action@v4.6.0
with:
file: ./coverage.xml

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ from jax.example_libraries import stax
init_fn, apply_fn = stax.serial(
stax.Dense(512), stax.Relu,
stax.Dense(512), stax.Relu,
stax.Dense(1)
stax.Dense(1),
)

key = random.PRNGKey(1)
Expand All @@ -123,7 +123,7 @@ from neural_tangents import stax
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(512), stax.Relu(),
stax.Dense(512), stax.Relu(),
stax.Dense(1)
stax.Dense(1),
)

key1, key2 = random.split(random.PRNGKey(1))
Expand Down
21 changes: 10 additions & 11 deletions examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_dataset(
permute_train=False,
do_flatten_and_normalize=True,
data_dir=None,
input_key='image'
input_key='image',
):
"""Download, parse and process a dataset to unit scale and one-hot labels."""
# Need this following http://cl/378185881 to prevent GPU test breakages.
Expand Down Expand Up @@ -133,9 +133,10 @@ def embed_glove(xs, glove_path, max_sentence_length=1000, mask_constant=1000.):
xs = list(map(_decode, xs))
tokenizer = tf.keras.preprocessing.text.Tokenizer()
tokenizer.fit_on_texts(np.concatenate(xs))
glove_embedding_layer = _get_glove_embedding_layer(tokenizer,
glove_path,
max_sentence_length)
glove_embedding_layer = _get_glove_embedding_layer(
tokenizer,
glove_path,
)

def embed(x):
# Replace strings with sequences of integer tokens.
Expand All @@ -147,7 +148,8 @@ def embed(x):
x_tok,
max_sentence_length,
padding='post',
truncating='post')
truncating='post',
)

# Replace integer tokens with word embeddings.
x_emb = glove_embedding_layer(x_tok).numpy()
Expand All @@ -160,7 +162,7 @@ def embed(x):
return map(embed, xs)


def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length):
def _get_glove_embedding_layer(tokenizer, glove_path):
"""Get a Keras embedding layer for a given GloVe embeddings.

Adapted from https://keras.io/examples/pretrained_word_embeddings/.
Expand All @@ -172,9 +174,6 @@ def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length):
glove_path:
path to the GloVe embedding file.

max_sentence_length:
pad/truncate embeddings to this length.

Returns:
Keras embedding layer for a given GloVe embeddings.
"""
Expand Down Expand Up @@ -212,8 +211,8 @@ def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length):
embedding_layer = tf.keras.layers.Embedding(
num_words, embedding_dim,
embeddings_initializer=tf.keras.initializers.Constant(emb_mat),
input_length=max_sentence_length,
trainable=False)
trainable=False,
)

return embedding_layer

Expand Down
2 changes: 1 addition & 1 deletion examples/function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main(unused_argv):
opt_init, opt_apply, get_params = optimizers.sgd(_LEARNING_RATE)
state = opt_init(params)

# Create an mse loss function and a gradient function.
# Create a mse loss function and a gradient function.
loss = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2)
grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

Expand Down
2 changes: 1 addition & 1 deletion examples/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main(*args, use_dummy_data: bool = False, **kwargs) -> None:


def _get_dummy_data(
mask_constant: float
mask_constant: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Return dummy data for when downloading embeddings is not feasible."""
n_train, n_test = 6, 6
Expand Down
Loading
Loading