
Commit 19b54c0

Migrate from Legacy JAX APIs jax.tree_util to jax.tree
1 parent 31e8da0 commit 19b54c0

24 files changed (+57, -77 lines)
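
The changed files all apply the same mechanical rewrite: every call into the legacy jax.tree_util namespace is replaced by its equivalent in the jax.tree namespace (available in recent JAX releases). A minimal sketch of the mapping, using an illustrative pytree that is not part of this commit:

import jax

tree = {"w": [1.0, 2.0], "b": 3.0}  # illustrative pytree, not from axlearn

leaves, treedef = jax.tree.flatten(tree)            # was jax.tree_util.tree_flatten
assert leaves == jax.tree.leaves(tree)              # was jax.tree_util.tree_leaves
assert treedef == jax.tree.structure(tree)          # was jax.tree_util.tree_structure
assert jax.tree.unflatten(treedef, leaves) == tree  # was jax.tree_util.tree_unflatten
doubled = jax.tree.map(lambda x: 2 * x, tree)       # was jax.tree_util.tree_map
total = jax.tree.reduce(lambda a, b: a + b, tree)   # was jax.tree_util.tree_reduce
assert doubled["b"] == 6.0 and total == 6.0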

axlearn/common/array_serialization.py
Lines changed: 1 addition & 3 deletions

@@ -378,9 +378,7 @@ async def _run_serializer():

 asyncio.run(_run_serializer())

-self._add_futures(
-    jax.tree_util.tree_flatten(commit_futures)[0] + (additional_futures or [])
-)
+self._add_futures(jax.tree.flatten(commit_futures)[0] + (additional_futures or []))

 # Used in wait_until_finished to check on process != 0, if the checkpoint
 # has finished writing.

axlearn/common/attention_test.py
Lines changed: 1 addition & 1 deletion

@@ -5161,7 +5161,7 @@ def has_prebuilt_layers(path):
     lambda path, spec: spec if has_prebuilt_layers(path) else None, param_specs
 )
 if prebuilt_layers:
-    self.assertNotEmpty(jax.tree_util.tree_leaves(prebuilt_specs))
+    self.assertNotEmpty(jax.tree.leaves(prebuilt_specs))
 initialized_state = layer.initialize_parameters_recursively(
     prng_key=jax.random.PRNGKey(123), prebuilt=prebuilt_specs
 )

axlearn/common/checkpointer.py
Lines changed: 1 addition & 3 deletions

@@ -572,9 +572,7 @@ def restore_from_dir(
 else:
     raise RuntimeError(f"Unknown index entry '{value}'")

-restored_state = jax.tree_util.tree_unflatten(
-    jax.tree_util.tree_structure(state), state_leaves
-)
+restored_state = jax.tree.unflatten(jax.tree.structure(state), state_leaves)
 multihost_utils.sync_global_devices(ckpt_dir)
 return restored_state

axlearn/common/checkpointer_test.py
Lines changed: 1 addition & 3 deletions

@@ -425,9 +425,7 @@ def test_custom_dict(self, checkpointer_cls, custom_dict_type):
 step, restored_state = ckpt.restore(step=None, state=state0)
 self.assertEqual(100, step)
 self.assertEqual(type(restored_state), custom_dict_type)
-self.assertIn(
-    custom_dict_type.__name__, str(jax.tree_util.tree_structure(restored_state))
-)
+self.assertIn(custom_dict_type.__name__, str(jax.tree.structure(restored_state)))
 self.assertNestedEqual(state0, restored_state)
 ckpt.stop()

axlearn/common/inference.py
Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ def __call__(self, input_batch: NestedTensor) -> Output:
     isinstance(x, jax.Array) and len(x.devices()) == 1
 ) or isinstance(x, np.ndarray)
 all_host_local_inputs = all(
-    is_host_local_input_check(t) for t in jax.tree_util.tree_leaves(input_batch)
+    is_host_local_input_check(t) for t in jax.tree.leaves(input_batch)
 )

 if all_host_local_inputs:

axlearn/common/inference_output.py
Lines changed: 1 addition & 1 deletion

@@ -199,7 +199,7 @@ def write(self, *, input_batch: NestedTensor, output_batch: NestedTensor):
     output_batch: A NestedTensor whose leaves must be tensors of shape [batch_size, ...].
 """
 local_data = dict(input=input_batch, output=output_batch)
-local_batch_size = jax.tree_util.tree_leaves(local_data)[0].shape[0]
+local_batch_size = jax.tree.leaves(local_data)[0].shape[0]

 for i in range(local_batch_size):
     example = jax.tree.map(lambda x, index=i: x[index], local_data)

axlearn/common/inference_pipeline.py
Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ def run(self, **kwargs):
 self.summary_writer(step=batch_index, values=output.summaries)

 if (batch_index + 1) % 10 == 0:
-    global_batch_size = len(jax.tree_util.tree_leaves(global_input_batch)[0])
+    global_batch_size = len(jax.tree.leaves(global_input_batch)[0])
     logging.info(
         "Processed %d batches and %d examples",
         batch_index + 1,

axlearn/common/learner.py
Lines changed: 1 addition & 3 deletions

@@ -427,9 +427,7 @@ def _learner_tree(self, params: Nested[Any]) -> Nested[str]:
     tree_paths(params),
 )
 # Check that all params is covered.
-if not jax.tree_util.tree_reduce(
-    lambda x, y: x and (y != ""), learner_name_tree, initializer=True
-):
+if not jax.tree.reduce(lambda x, y: x and (y != ""), learner_name_tree, initializer=True):
     raise ValueError("Composite learner rules do not update all model params.")
 return learner_name_tree
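
For context, the coverage check above folds over the name tree with jax.tree.reduce, seeding the fold with initializer=True. A toy illustration (tree values here are made up, not axlearn's):

import jax

learner_name_tree = {"encoder": "adam", "decoder": ""}  # made-up example tree
all_covered = jax.tree.reduce(
    lambda x, y: x and (y != ""), learner_name_tree, initializer=True
)
assert all_covered is False  # the empty "decoder" entry fails the check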

axlearn/common/learner_test.py
Lines changed: 7 additions & 7 deletions

@@ -165,8 +165,8 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
 self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
 # The structure of updated params and Adam mu states are same.
 self.assertNestedEqual(
-    jax.tree_util.tree_structure(updated_model_params),
-    jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
+    jax.tree.structure(updated_model_params),
+    jax.tree.structure(learner_state["optimizer"][1].mu),
 )

 @parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))

@@ -983,14 +983,14 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
 # The structure of updated params and optimizer states are same.
 opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
 self.assertNestedEqual(
-    jax.tree_util.tree_structure(updated_model_params),
-    jax.tree_util.tree_structure(
+    jax.tree.structure(updated_model_params),
+    jax.tree.structure(
         learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
     ),
 )
 self.assertNestedEqual(
-    jax.tree_util.tree_structure(updated_model_params),
-    jax.tree_util.tree_structure(
+    jax.tree.structure(updated_model_params),
+    jax.tree.structure(
         learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
     ),
 )

@@ -1156,7 +1156,7 @@ def loss_fn(model_params, inputs):
     summaries={},
     module_outputs={},
 )
-result = jax.tree_util.tree_reduce(lambda x, y: x.sum() + y.sum(), model_params)
+result = jax.tree.reduce(lambda x, y: x.sum() + y.sum(), model_params)
 return ForwardOutputs(loss=result, aux={}, output_collection=output_collection)

 grads = jax.tree_map(lambda p: jnp.ones_like(p.value), params)

axlearn/common/metrics_test.py
Lines changed: 6 additions & 6 deletions

@@ -51,8 +51,8 @@ def test_metric_accumulator(self):
 )

 chex.assert_trees_all_equal_structs(result, expected)
-result = jax.tree_util.tree_leaves(result)
-expected = jax.tree_util.tree_leaves(expected)
+result = jax.tree.leaves(result)
+expected = jax.tree.leaves(expected)
 chex.assert_trees_all_close(result, expected)

 def test_flatten_unflatten_metric_accumulator(self):

@@ -75,10 +75,10 @@ def test_flatten_unflatten_metric_accumulator(self):
 for s in summaries_copy:
     acc.update(s)

-flat, tree = jax.tree_util.tree_flatten(acc)
-unflattened = jax.tree_util.tree_unflatten(tree, flat)
-expected = jax.tree_util.tree_leaves(acc.summaries())
-result = jax.tree_util.tree_leaves(unflattened.summaries())
+flat, tree = jax.tree.flatten(acc)
+unflattened = jax.tree.unflatten(tree, flat)
+expected = jax.tree.leaves(acc.summaries())
+result = jax.tree.leaves(unflattened.summaries())
 chex.assert_trees_all_close(result, expected)

 @parameterized.parameters(

axlearn/common/mixture_of_experts.py
Lines changed: 1 addition & 1 deletion

@@ -991,7 +991,7 @@ def convert_fn(source_parameters: Nested[Tensor]) -> Nested[Tensor]:
 ) from e
 # The target layer is a RepeatedTransformerLayer.
 target_parameters = {"repeat": VDict({"layer": {}})}
-num_stages = jax.tree_util.tree_leaves(stage_parameter_specs)[0].shape[0]
+num_stages = jax.tree.leaves(stage_parameter_specs)[0].shape[0]
 # The target stage is expected to be a StackedTransformerLayer.
 num_layers_per_stage = len(stage_parameter_specs)
 for layer_i in range(num_layers_per_stage):

axlearn/common/module.py
Lines changed: 1 addition & 1 deletion

@@ -202,7 +202,7 @@ def propagate_repeated_output_collections(
 # if a repeated layer outputs a scalar summary value, it will have shape [N].
 # Below we split the stacked values and output them separately under scope
 # "{child_name_prefix}{i}" so that scalar summaries can be handled correctly.
-summary_values = jax.tree_util.tree_leaves(repeated_output_collection.summaries)
+summary_values = jax.tree.leaves(repeated_output_collection.summaries)
 if summary_values:
     first_summary_value = summary_values[0]
     assert first_summary_value.shape, "Stacked summaries should have a leading stack dimension."

axlearn/common/optimizers.py
Lines changed: 1 addition & 1 deletion

@@ -1351,7 +1351,7 @@ def _is_valid_step(
     return is_valid, new_drop_stats

 # Check if every gradient is finite.
-flat_updates = jax.tree_util.tree_flatten(updates)[0]
+flat_updates = jax.tree.flatten(updates)[0]
 is_finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
 g_norm = optax.global_norm(updates)
 if drop_norm is not None:
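
The flattened-updates check above reduces to the pattern below; the toy gradient tree is illustrative only, not the axlearn API:

import jax
import jax.numpy as jnp

updates = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([jnp.inf])}  # toy gradients
flat_updates = jax.tree.flatten(updates)[0]
is_finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
assert not bool(is_finite)  # "b" contains inf, so this step would be treated as invalid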

axlearn/common/pipeline.py
Lines changed: 1 addition & 1 deletion

@@ -766,7 +766,7 @@ def _run(
 cfg: Pipeline.Config = self.config
 self.vlog(1, "carry=%s xs=%s", shapes(carry), shapes(xs))

-carry_leaves = jax.tree_util.tree_leaves(carry)
+carry_leaves = jax.tree.leaves(carry)
 if not carry_leaves:
     raise ValueError("Expected at least one input leaf.")
 if carry_leaves[0].ndim < 2:

axlearn/common/repeat.py
Lines changed: 1 addition & 1 deletion

@@ -190,7 +190,7 @@ def _run(self, fn, carry=None, *, xs=None):
 with child_context("layer", output_collection=layer_output_collection) as layer_context:
     # Note, actual `num_layers` might be smaller than `cfg.num_layers` depending on
     # the invocation context.
-    num_layers = jax.tree_util.tree_reduce(
+    num_layers = jax.tree.reduce(
         lambda num, x: min(num, x.shape[0]),
         tree=(layer_context.state, xs),
         initializer=cfg.num_layers,

axlearn/common/rnn_test.py
Lines changed: 3 additions & 5 deletions

@@ -185,11 +185,9 @@ def test_repeat_forward_vs_layerwise(self, norm_cfg, hidden_dim, num_layers):
 final_states_list.append(output_collections.module_outputs["final_states"])

 # Stack the tree leaves.
-tree_leaves = [jax.tree_util.tree_flatten(t)[0] for t in final_states_list]
-tree_def = jax.tree_util.tree_structure(final_states_list[0])
-final_states = jax.tree_util.tree_unflatten(
-    tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)]
-)
+tree_leaves = [jax.tree.flatten(t)[0] for t in final_states_list]
+tree_def = jax.tree.structure(final_states_list[0])
+final_states = jax.tree.unflatten(tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)])
 self.assertEqual(shapes(final_states), shapes(init_states))

 forward_outputs, forward_collections = F(
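
The rewritten block above is a list-of-trees to tree-of-stacked-leaves transpose. A minimal sketch of the same pattern on a toy list of state trees (keys and shapes are illustrative):

import jax
import jax.numpy as jnp

states_list = [{"c": jnp.zeros(3), "h": jnp.ones(3)} for _ in range(4)]  # 4 steps

tree_leaves = [jax.tree.flatten(t)[0] for t in states_list]
tree_def = jax.tree.structure(states_list[0])
stacked = jax.tree.unflatten(tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)])
assert stacked["h"].shape == (4, 3)  # leading axis is the step index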

axlearn/common/state_builder_test.py
Lines changed: 2 additions & 4 deletions

@@ -673,9 +673,7 @@ def _run_builder(
     **extra_converter_config_kwargs,
 ):
 source_state = _mock_state(source_cfg, seed=0)
-initial_trainer_state_tree_structure = jax.tree_util.tree_structure(
-    source_state.trainer_state
-)
+initial_trainer_state_tree_structure = jax.tree.structure(source_state.trainer_state)

 builder = (
     builder_cls.default_config()

@@ -689,7 +687,7 @@ def _run_builder(
 source_model = source_state.trainer_state.model

 converted_state = builder(deepcopy(source_state))
-assert initial_trainer_state_tree_structure == jax.tree_util.tree_structure(
+assert initial_trainer_state_tree_structure == jax.tree.structure(
     converted_state.trainer_state
 )
 converted_model = converted_state.trainer_state.model

axlearn/common/struct_test.py
Lines changed: 3 additions & 3 deletions

@@ -58,7 +58,7 @@ class SlotsPoint:

 def test_pytree_nodes(self):
     p = _Point(x=1, y=2, meta={"abc": True})
-    leaves = jax.tree_util.tree_leaves(p)
+    leaves = jax.tree.leaves(p)
     self.assertEqual(leaves, [1, 2])
     new_p = jax.tree.map(lambda x: x + x, p)
     self.assertEqual(new_p, _Point(x=2, y=4, meta={"abc": True}))

@@ -104,7 +104,7 @@ def test_chex_tree_leaves_compatibility(self):
 )
 # tree_flatten_with_path is not preserved because Chex does not support this so the
 # fallback jax implementation with numbered keys gets used.
-flattened.append(jax.tree_util.tree_leaves(instance))
+flattened.append(jax.tree.leaves(instance))
 chex.assert_trees_all_equal(*flattened)

 def test_constructor_order(self):

@@ -133,7 +133,7 @@ class C:
     field_b: int
     field_a: int

-result = jax.tree_util.tree_leaves(C(field_b=1, field_a=2))
+result = jax.tree.leaves(C(field_b=1, field_a=2))
 expected = (1, 2)
 self.assertSequenceEqual(result, expected)

axlearn/common/test_utils.py
Lines changed: 5 additions & 7 deletions

@@ -196,8 +196,8 @@ def _compute_layer_outputs(
 # Optionally, test that trees also have the same structure.
 if require_same_tree_structure:
     # Prune empty subtrees so we don't require empty dicts for layers with no params.
-    ref_structure = jax.tree_util.tree_structure(prune_empty(params_from_ref))
-    test_structure = jax.tree_util.tree_structure(prune_empty(layer_params))
+    ref_structure = jax.tree.structure(prune_empty(params_from_ref))
+    test_structure = jax.tree.structure(prune_empty(layer_params))
     self.assertEqual(
         ref_structure, test_structure, msg=f"\nRef: {ref_structure}\nTest: {test_structure}"
     )

@@ -428,8 +428,8 @@ def replace_keys(v, mapping):
 params_with_nones = jax.tree_map(
     partial(replace_keys, mapping={k: None for k in delegates}), params, is_leaf=is_leaf
 )
-_, treedef = jax.tree_util.tree_flatten(params_with_nones)
-inits_with_nones = jax.tree_util.tree_unflatten(treedef, param_init_specs)
+_, treedef = jax.tree.flatten(params_with_nones)
+inits_with_nones = jax.tree.unflatten(treedef, param_init_specs)

 # Replace the Nones with a delegate.
 return jax.tree_map(partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf)

@@ -563,9 +563,7 @@ def patched_register_per_param_settings(
 model_params = model.initialize_parameters_recursively(jax.random.PRNGKey(0))

 model_specs = model.create_parameter_specs_recursively()
-model_specs = complete_partition_spec_tree(
-    jax.tree_util.tree_structure(model_params), model_specs
-)
+model_specs = complete_partition_spec_tree(jax.tree.structure(model_params), model_specs)
 opt_params = jax.tree.map(
     lambda param, spec: OptParam(
         value=param,

axlearn/common/trainer.py
Lines changed: 2 additions & 2 deletions

@@ -604,7 +604,7 @@ def _opt_params(self, model_params: NestedTensor) -> NestedOptParam:
 """Returns a tree of OptParam for Learner.{init,update}."""
 # self._model_param_specs can be incomplete. Complete it first.
 specs = utils.complete_partition_spec_tree(
-    jax.tree_util.tree_structure(model_params), self._model_param_specs
+    jax.tree.structure(model_params), self._model_param_specs
 )
 return jax.tree.map(
     lambda param, spec: OptParam(

@@ -852,7 +852,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
 # Log trainer state tree.
 if not self.step and jax.process_index() == 0:
     with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f:
-        f.write(str(jax.tree_util.tree_structure(self._trainer_state)))
+        f.write(str(jax.tree.structure(self._trainer_state)))

     with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
         f.write(model_analysis)

axlearn/common/trainer_test.py
Lines changed: 1 addition & 3 deletions

@@ -579,9 +579,7 @@ def test_compile_train_step(self, *, platform, mesh_shape):
 trainer: SpmdTrainer = cfg.instantiate(parent=None)
 compiled_without_args = trainer.compile_train_step()
 # pylint: disable=protected-access
-input_batch = jax.tree_util.tree_map(
-    jnp.array, next(trainer.input.batches(trainer._input_iter))
-)
+input_batch = jax.tree.map(jnp.array, next(trainer.input.batches(trainer._input_iter)))
 # pylint: enable=protected-access
 compiled_with_input_batch = trainer.compile_train_step(input_batch=input_batch)
 # In a single-host environment, both compiled functions should match.

axlearn/common/utils.py
Lines changed: 10 additions & 12 deletions

@@ -441,7 +441,7 @@ def vectorized_tree_map(fn, tree, *rest):

 def vectorized_fn(*nodes):
     if isinstance(nodes[0], VDict):
-        if not jax.tree_util.tree_leaves(nodes[0]):
+        if not jax.tree.leaves(nodes[0]):
             # This can happen when all VDict values are None and cause issues with jax.vmap.
             return nodes[0]
         nodes = [dict(**node) for node in nodes]

@@ -469,7 +469,7 @@ def fn(value: Union[Tensor, VDict]) -> NestedTensor:
 if not isinstance(value, VDict):
     return value

-leaves = jax.tree_util.tree_leaves(value)
+leaves = jax.tree.leaves(value)
 if not leaves:
     # An empty VDict.
     return value

@@ -653,7 +653,7 @@ def complete_partition_spec_tree(
     prefix of treedef.
 """
 proxy = object()
-dummy = jax.tree_util.tree_unflatten(treedef, [object()] * treedef.num_leaves)
+dummy = jax.tree.unflatten(treedef, [object()] * treedef.num_leaves)
 axes = []

 def replace_none_with_proxy(tree):

@@ -672,17 +672,17 @@ def replace_none_with_proxy(tree):
 partition_spec_tree_with_proxy = replace_none_with_proxy(partition_spec_tree)

 def add_leaves(i, x):
-    axes.extend([i] * len(jax.tree_util.tree_flatten(x)[0]))
+    axes.extend([i] * len(jax.tree.flatten(x)[0]))

 try:
     jax.tree.map(add_leaves, partition_spec_tree_with_proxy, dummy)
 except ValueError as err:
     logging.info("[complete_partition_spec_tree] ValueError: %s", err)
     logging.info(
         "[complete_partition_spec_tree] partition_spec_tree_with_proxy=%s",
-        jax.tree_util.tree_structure(partition_spec_tree_with_proxy),
+        jax.tree.structure(partition_spec_tree_with_proxy),
     )
-    logging.info("[complete_partition_spec_tree] dummy=%s", jax.tree_util.tree_structure(dummy))
+    logging.info("[complete_partition_spec_tree] dummy=%s", jax.tree.structure(dummy))
     for path, value in flatten_items(partition_spec_tree_with_proxy):
         logging.info(
             "[complete_partition_spec_tree] partition_spec_tree_with_proxy leaf: %s=%s",

@@ -701,7 +701,7 @@ def add_leaves(i, x):
 assert (
     len(axes) == treedef.num_leaves
 ), f"({len(axes)} vs. {treedef.num_leaves}) {axes} {treedef}"
-return jax.tree_util.tree_unflatten(treedef, axes)
+return jax.tree.unflatten(treedef, axes)


 def input_partition_spec() -> PartitionSpec:

@@ -801,9 +801,7 @@ def host_to_global_device_array(
 """
 mesh = thread_resources.env.physical_mesh
 partition_spec = data_partition_type_to_spec(partition)
-partition_specs = complete_partition_spec_tree(
-    jax.tree_util.tree_structure(host_arrays), partition_spec
-)
+partition_specs = complete_partition_spec_tree(jax.tree.structure(host_arrays), partition_spec)
 process_count = jax.process_count()

 def make_gda(x, partition_spec):

@@ -1031,7 +1029,7 @@ def cast(x: Union[Tensor, TensorSpec]) -> Union[Tensor, TensorSpec]:

 def count_model_params(tree: NestedTensor) -> int:
     """Count the number of parameters in a model."""
-    return sum(x.size for x in jax.tree_util.tree_leaves(tree))
+    return sum(x.size for x in jax.tree.leaves(tree))


 def check_param_shape_alignment(

@@ -1095,7 +1093,7 @@ def check_jax_type(
 pretty_named_args.update({f"kwargs[{key}]": kwargs[key] for key in kwargs})

 for name, arg in pretty_named_args.items():
-    values, _ = jax.tree_util.tree_flatten(arg)
+    values, _ = jax.tree.flatten(arg)
     for value in values:
         if not isinstance(value, (type(None), jax.Array, int, float)):
             if msg is None:
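
Several of the utils.py call sites pair jax.tree.structure with jax.tree.unflatten; complete_partition_spec_tree in particular builds a dummy tree with one placeholder per leaf. A minimal sketch of that sub-pattern on an illustrative tree:

import jax

treedef = jax.tree.structure({"a": 1, "b": (2, 3)})
dummy = jax.tree.unflatten(treedef, [object()] * treedef.num_leaves)
assert jax.tree.structure(dummy) == treedef  # same structure, placeholder leaves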
