diff --git a/axlearn/common/base_layer_test.py b/axlearn/common/base_layer_test.py
index 5d5552bc3..d3de1c1f1 100644
--- a/axlearn/common/base_layer_test.py
+++ b/axlearn/common/base_layer_test.py
@@ -11,8 +11,8 @@
 from unittest import mock
 
 import jax.ad_checkpoint
-import jax.core
-import jax.interpreters.ad
+import jax.extend.core
+import jax.extend.ad
 import jax.random
 import numpy as np
 from absl.testing import absltest, parameterized
@@ -130,7 +130,7 @@ def backward_impl(x):
     prim = jax.extend.core.Primitive("passthrough_with_callback")
     prim.def_impl(forward_impl)
     prim.def_abstract_eval(forward_impl)
-    jax.interpreters.ad.deflinear(prim, backward_impl)
+    jax.extend.ad.deflinear(prim, backward_impl)
     return prim.bind
diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py
index ad075528c..c9c0f1332 100644
--- a/axlearn/common/flash_attention/layer.py
+++ b/axlearn/common/flash_attention/layer.py
@@ -9,7 +9,7 @@
 import jax.numpy as jnp
 import numpy as np
 from jax.experimental.shard_map import shard_map
-from jax.interpreters.pxla import thread_resources
+from jax.extend.pxla import thread_resources
 from jax.sharding import PartitionSpec
 
 from axlearn.common.attention import Dropout, ForwardMode, GroupedQueryAttention
diff --git a/axlearn/common/update_transformation.py b/axlearn/common/update_transformation.py
index 4c36bacce..3653a39b3 100644
--- a/axlearn/common/update_transformation.py
+++ b/axlearn/common/update_transformation.py
@@ -186,7 +186,7 @@ def real_transform(_):
             return new_updates.delta_updates, new_state
 
         def stop_transform(_):
-            return jax.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
+            return jax.tree_util.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
 
         # We do the computation regardless of the should_update value, so we could have
         # equally used jnp.where() here instead.
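
As a rough, self-contained illustration of the comment in the last hunk (both branches are computed and the result is merely selected), here is a sketch of the jnp.where variant over a pytree of updates. The helper name `maybe_zero_updates` and the toy dict pytree are hypothetical stand-ins, not the actual UpdateTransformation state; only `jax.tree_util.tree_map`, `jnp.zeros_like`, and `jnp.where` are real APIs.

import jax
import jax.numpy as jnp


def maybe_zero_updates(delta_updates, should_update):
    # Hypothetical helper: keep `delta_updates` when `should_update` is True,
    # otherwise return an all-zeros pytree with the same structure.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, delta_updates)
    # Both candidate trees are materialized; jnp.where only selects per leaf,
    # so the work is done regardless of the flag's value.
    return jax.tree_util.tree_map(
        lambda u, z: jnp.where(should_update, u, z), delta_updates, zeros
    )


# Toy usage with a dict-of-arrays pytree.
updates = {"w": jnp.ones((2, 3)), "b": jnp.ones(3)}
zeroed = maybe_zero_updates(updates, jnp.asarray(False))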