From 22355f0922e6fa7feaf03411668134f35bf19978 Mon Sep 17 00:00:00 2001
From: Manny Hernandez <86986433+m-ny@users.noreply.github.com>
Date: Fri, 23 May 2025 02:52:48 +0800
Subject: [PATCH 1/3] Update update_transformation.py

---
 axlearn/common/update_transformation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/common/update_transformation.py b/axlearn/common/update_transformation.py
index 4c36bacce..3653a39b3 100644
--- a/axlearn/common/update_transformation.py
+++ b/axlearn/common/update_transformation.py
@@ -186,7 +186,7 @@ def real_transform(_):
             return new_updates.delta_updates, new_state
 
         def stop_transform(_):
-            return jax.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
+            return jax.tree_util.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
 
         # We do the computation regardless of the should_update value, so we could have
         # equally used jnp.where() here instead.

From aee5d378aca81f0c50aff5fd098de3d3e357a5eb Mon Sep 17 00:00:00 2001
From: Manny Hernandez <86986433+m-ny@users.noreply.github.com>
Date: Fri, 23 May 2025 02:56:36 +0800
Subject: [PATCH 2/3] In base_layer_test.py: Updated import from jax.core to
 jax.extend.core Updated import from jax.interpreters.ad to jax.extend.ad
 Updated the function call from jax.interpreters.ad.deflinear to
 jax.extend.ad.deflinear

---
 axlearn/common/base_layer_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/axlearn/common/base_layer_test.py b/axlearn/common/base_layer_test.py
index 5d5552bc3..d3de1c1f1 100644
--- a/axlearn/common/base_layer_test.py
+++ b/axlearn/common/base_layer_test.py
@@ -11,8 +11,8 @@ from unittest import mock
 
 import jax.ad_checkpoint
-import jax.core
-import jax.interpreters.ad
+import jax.extend.core
+import jax.extend.ad
 import jax.random
 import numpy as np
 from absl.testing import absltest, parameterized
@@ -130,7 +130,7 @@ def backward_impl(x):
     prim = jax.extend.core.Primitive("passthrough_with_callback")
     prim.def_impl(forward_impl)
     prim.def_abstract_eval(forward_impl)
-    jax.interpreters.ad.deflinear(prim, backward_impl)
+    jax.extend.ad.deflinear(prim, backward_impl)
     return prim.bind

From 4c70a61d5a4b1b9263c7475c8667b34066da6577 Mon Sep 17 00:00:00 2001
From: Manny Hernandez <86986433+m-ny@users.noreply.github.com>
Date: Fri, 23 May 2025 02:57:45 +0800
Subject: [PATCH 3/3] Updated import from jax.interpreters.pxla to
 jax.extend.pxla

---
 axlearn/common/flash_attention/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py
index ad075528c..c9c0f1332 100644
--- a/axlearn/common/flash_attention/layer.py
+++ b/axlearn/common/flash_attention/layer.py
@@ -9,7 +9,7 @@
 import jax.numpy as jnp
 import numpy as np
 from jax.experimental.shard_map import shard_map
-from jax.interpreters.pxla import thread_resources
+from jax.extend.pxla import thread_resources
 from jax.sharding import PartitionSpec
 
 from axlearn.common.attention import Dropout, ForwardMode, GroupedQueryAttention
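
Note on PATCH 1/3: jax.tree_map has been deprecated for several JAX releases
and is removed in recent ones; jax.tree_util.tree_map is the stable path with
identical semantics for this use. A minimal sketch of the replacement call
(illustrative pytree, not taken from axlearn):

    import jax.numpy as jnp
    import jax.tree_util

    # Mirrors the stop_transform branch: produce a zeroed copy of every leaf
    # in a pytree of delta updates.
    delta_updates = {"w": jnp.ones((2, 2)), "b": jnp.ones((2,))}
    zeroed = jax.tree_util.tree_map(jnp.zeros_like, delta_updates)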
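
Note on PATCH 2/3: the test builds a custom primitive that is the identity on
the forward pass, and deflinear registers it as linear, which supplies both a
JVP and a transpose rule. A self-contained sketch of the migrated pattern,
assuming jax.extend.core and jax.extend.ad expose Primitive and deflinear as
the commit message states (availability depends on the installed JAX version;
older releases ship the same names under jax.core and jax.interpreters.ad):

    import jax
    import jax.extend.ad    # assumed to exist per the commit message
    import jax.extend.core
    import jax.numpy as jnp

    def forward_impl(x):
        # Identity forward pass; reused as the abstract eval rule because the
        # output aval equals the input aval.
        return x

    def backward_impl(ct):
        # Transpose rule of a linear primitive: pass the cotangent through,
        # returning one cotangent per primal input.
        return (ct,)

    prim = jax.extend.core.Primitive("passthrough_with_callback")
    prim.def_impl(forward_impl)
    prim.def_abstract_eval(forward_impl)
    jax.extend.ad.deflinear(prim, backward_impl)

    # prim.bind now behaves like an identity with a custom transpose:
    grads = jax.grad(lambda x: prim.bind(x).sum())(jnp.ones(3))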
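
Note on PATCH 3/3: thread_resources is the thread-local holder for the mesh
installed by an enclosing jax.sharding.Mesh context; the flash-attention layer
reads it to recover the physical mesh for shard_map. A sketch of the access
pattern, assuming jax.extend.pxla re-exports thread_resources as the commit
message states (older JAX versions use the jax.interpreters.pxla path):

    import jax
    import numpy as np
    from jax.extend.pxla import thread_resources  # assumed path per the patch
    from jax.sharding import Mesh

    devices = np.array(jax.devices()).reshape(-1, 1)
    with Mesh(devices, axis_names=("data", "model")):
        # Inside the context, the resource env carries the active mesh.
        mesh = thread_resources.env.physical_mesh
        assert mesh.axis_names == ("data", "model")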