From cbdcde5e745ba29ef84eb82e11f2aaa7af24dca1 Mon Sep 17 00:00:00 2001 From: yiping-ma Date: Thu, 19 Jun 2025 14:15:01 +0000 Subject: [PATCH] Add opensource MoE switch model for TPU v6e testing --- axlearn/common/mixture_of_experts.py | 52 +- axlearn/common/mixture_of_experts_test.py | 59 +- .../envy-Switch-Base-single-host.txt | 496 ++++++++++++++++ .../envy-Switch-Base-single-host_init.txt | 23 + ...vy-Switch-Base-single-host_regularizer.txt | 24 + .../envy-Switch-Base.txt | 493 ++++++++++++++++ .../envy-Switch-Base_init.txt | 23 + .../envy-Switch-Base_regularizer.txt | 24 + .../envy-Switch-Large.txt | 508 ++++++++++++++++ .../envy-Switch-Large_init.txt | 23 + .../envy-Switch-Large_regularizer.txt | 24 + .../envy-Switch-XXL.txt | 469 +++++++++++++++ .../envy-Switch-XXL_init.txt | 23 + .../envy-Switch-XXL_regularizer.txt | 24 + .../envy-test.txt | 401 +++++++++++++ .../envy-test_init.txt | 21 + .../envy-test_regularizer.txt | 22 + .../fuji-70B-v1-flash-fp8.txt | 16 +- .../fuji-70B-v1-flash.txt | 16 +- .../fuji-70B-v1-fp8.txt | 16 +- .../fuji-70B-v1.txt | 16 +- .../fuji-70B-v2-flash-fp8.txt | 16 +- .../fuji-70B-v2-flash.txt | 16 +- .../fuji-70B-v2-fp8.txt | 16 +- .../fuji-70B-v2.txt | 16 +- .../fuji-70B-v3-flash-fp8.txt | 16 +- .../fuji-70B-v3-flash.txt | 16 +- .../fuji-70B-v3-fp8.txt | 16 +- .../fuji-70B-v3-tiktoken-flash-fp8.txt | 16 +- .../fuji-70B-v3-tiktoken-flash.txt | 16 +- .../fuji-70B-v3-tiktoken-fp8.txt | 16 +- .../fuji-70B-v3-tiktoken.txt | 16 +- .../fuji-70B-v3.txt | 16 +- .../fuji-7B-v1-flash-fp8-single-host.txt | 14 +- .../fuji-7B-v1-flash-fp8.txt | 14 +- .../fuji-7B-v1-flash-single-host.txt | 14 +- .../fuji-7B-v1-flash.txt | 14 +- .../fuji-7B-v1-fp8-single-host.txt | 14 +- .../fuji-7B-v1-fp8.txt | 14 +- .../fuji-7B-v1-single-host.txt | 14 +- .../fuji-7B-v1.txt | 14 +- .../fuji-7B-v2-flash-fp8-single-host.txt | 14 +- .../fuji-7B-v2-flash-fp8.txt | 14 +- .../fuji-7B-v2-flash-single-host.txt | 14 +- .../fuji-7B-v2-flash.txt | 14 +- .../fuji-7B-v2-fp8-single-host.txt | 14 +- .../fuji-7B-v2-fp8.txt | 14 +- .../fuji-7B-v2-single-host.txt | 14 +- .../fuji-7B-v2.txt | 14 +- .../fuji-7B-v3-flash-fp8-single-host.txt | 14 +- .../fuji-7B-v3-flash-fp8.txt | 14 +- .../fuji-7B-v3-flash-single-host.txt | 14 +- .../fuji-7B-v3-flash.txt | 14 +- .../fuji-7B-v3-fp8-single-host.txt | 14 +- .../fuji-7B-v3-fp8.txt | 14 +- .../fuji-7B-v3-single-host.txt | 14 +- .../fuji-7B-v3.txt | 14 +- ...i-8B-v3-tiktoken-flash-fp8-single-host.txt | 8 +- .../fuji-8B-v3-tiktoken-flash-fp8.txt | 8 +- .../fuji-8B-v3-tiktoken-flash-single-host.txt | 8 +- .../fuji-8B-v3-tiktoken-flash.txt | 8 +- .../fuji-8B-v3-tiktoken-fp8-single-host.txt | 8 +- .../fuji-8B-v3-tiktoken-fp8.txt | 8 +- .../fuji-8B-v3-tiktoken-single-host.txt | 8 +- .../fuji-8B-v3-tiktoken.txt | 8 +- axlearn/experiments/text/gpt/c4_trainer.py | 3 +- axlearn/experiments/text/gpt/common.py | 84 ++- axlearn/experiments/text/gpt/envy.py | 540 ++++++++++++++++++ axlearn/experiments/text/gpt/fuji.py | 54 +- 69 files changed, 3493 insertions(+), 553 deletions(-) create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base.txt create mode 100644 
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_init.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_regularizer.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_init.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_regularizer.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_init.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_regularizer.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_init.txt
 create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_regularizer.txt
 create mode 100644 axlearn/experiments/text/gpt/envy.py

diff --git a/axlearn/common/mixture_of_experts.py b/axlearn/common/mixture_of_experts.py
index 8d5cf12f0..eea475462 100644
--- a/axlearn/common/mixture_of_experts.py
+++ b/axlearn/common/mixture_of_experts.py
@@ -15,7 +15,8 @@
 """
 
 import re
-from typing import NamedTuple, Optional, Union
+from functools import reduce
+from typing import NamedTuple, Optional, Sequence, Union
 
 import jax
 import jax.numpy as jnp
@@ -45,6 +46,8 @@
 from axlearn.common.param_init import FanAxes, constant_initializer
 from axlearn.common.quantized_dot_general.layers import DenseGeneralBaseLayer
 from axlearn.common.utils import (
+    HybridMeshShape,
+    MeshShape,
     Nested,
     NestedTensor,
     PartitionSpec,
@@ -52,6 +55,7 @@
     VDict,
     flatten_items,
     get_recursively,
+    infer_mesh_shape,
     set_recursively,
     tree_paths,
     with_sharding_constraint,
@@ -167,6 +171,52 @@ def _cap_logits(logits: Tensor, gating_logit_cap: float) -> Tensor:
     return logits
 
 
+def get_outer_batch_from_mesh(
+    *,
+    mesh_axis_names: Sequence[str],
+    outer_batch_axis_names: Sequence[str],
+    mesh_shape: Optional[Union[MeshShape, HybridMeshShape]],
+) -> Optional[int]:
+    """Infer MoE outer batch size from mesh shape.
+
+    Args:
+        mesh_axis_names: The name of each mesh axis.
+        outer_batch_axis_names: The names of the mesh axes corresponding to the outer batch size.
+        mesh_shape: The size of each mesh axis corresponding to `mesh_axis_names`.
+            If None, the returned outer batch size will also be None.
+
+    Returns:
+        The MoE outer batch size. Will be None if `mesh_shape` is None.
+    """
+    if mesh_shape is None:
+        return None
+
+    ici_mesh_shape = (
+        mesh_shape.ici_mesh_shape if isinstance(mesh_shape, HybridMeshShape) else mesh_shape
+    )
+    try:
+        ici_mesh_shape = infer_mesh_shape(ici_mesh_shape)
+    except ValueError as e:
+        # This can happen when running locally, where the number of available devices may be
+        # smaller than the number of devices required by the mesh shape.
+        logging.info(e)
+
+    if isinstance(mesh_shape, HybridMeshShape):
+        if -1 in mesh_shape.dcn_mesh_shape:
+            # TODO(markblee): Improve support for this. At the moment it is not a use-case.
+            raise NotImplementedError(
+                "Unable to infer number of granules. Please specify dcn_mesh_shape without -1."
+ ) + mesh_shape = tuple(x * y for x, y in zip(ici_mesh_shape, mesh_shape.dcn_mesh_shape)) + else: + mesh_shape = ici_mesh_shape + + return reduce( + lambda x, y: x * y, + [mesh_shape[mesh_axis_names.index(el)] for el in outer_batch_axis_names], + ) + + class AdaptiveLoadBalanceLoss(BaseLayer): """A layer to adjust the aux loss weight based on the overcapacity ratio. diff --git a/axlearn/common/mixture_of_experts_test.py b/axlearn/common/mixture_of_experts_test.py index e6c420a75..199c12670 100644 --- a/axlearn/common/mixture_of_experts_test.py +++ b/axlearn/common/mixture_of_experts_test.py @@ -35,6 +35,7 @@ TransformerFeedForwardMoE, _convert_feedforward_to_moe_parameters, convert_dense_to_moe_parameters, + get_outer_batch_from_mesh, ) from axlearn.common.module import functional as F from axlearn.common.quantized_dot_general.activation_clipping import TanhActivationClippingLayer @@ -43,7 +44,14 @@ QuantizedDotGeneral, ) from axlearn.common.test_utils import TestCase, assert_allclose -from axlearn.common.utils import get_recursively, set_recursively, shapes +from axlearn.common.utils import ( + HybridMeshShape, + MeshShape, + get_recursively, + infer_mesh_shape, + set_recursively, + shapes, +) # pylint: disable=no-self-use,protected-access @@ -846,5 +854,54 @@ def test_dense_to_moe_parameters(self): assert_allclose(outputs_dense.data, outputs_moe.data) +class GetOuterBatchFromMeshTest(absltest.TestCase): + """Tests get_outer_batch_from_mesh.""" + + def test_mesh_shape_is_none(self): + result = get_outer_batch_from_mesh( + mesh_axis_names=["data", "model"], + outer_batch_axis_names=["data"], + mesh_shape=None, + ) + self.assertIsNone(result) + + def test_regular_mesh_shape(self): + mesh_shape: MeshShape = (2, 4, 8) + result = get_outer_batch_from_mesh( + mesh_axis_names=["data", "fsdp", "model"], + outer_batch_axis_names=["data", "fsdp"], + mesh_shape=mesh_shape, + ) + self.assertEqual(result, 2 * 4) + + def test_outer_batch_with_hybrid_mesh_shape(self): + hybrid = HybridMeshShape(ici_mesh_shape=(2, 8, 1), dcn_mesh_shape=(4, 1, 1)) + result = get_outer_batch_from_mesh( + mesh_axis_names=["data", "model"], + outer_batch_axis_names=["data"], + mesh_shape=hybrid, + ) + self.assertEqual(result, 2 * 4) + + def test_outer_batch_with_hybrid_mesh_shape_with_raises(self): + hybrid = HybridMeshShape(ici_mesh_shape=(2, 4, 1), dcn_mesh_shape=(-1, 2, 1)) + with self.assertRaisesRegex(NotImplementedError, "Unable to infer number of granules"): + get_outer_batch_from_mesh( + mesh_axis_names=["data", "fsdp", "model"], + outer_batch_axis_names=["data", "fsdp"], + mesh_shape=hybrid, + ) + + def test_outer_batch_with_infer_mesh_shape(self): + mesh_shape: MeshShape = (-1, 2, 4) + inferred_shape = infer_mesh_shape(mesh_shape, num_devices=8) + result = get_outer_batch_from_mesh( + mesh_axis_names=["data", "fsdp", "model"], + outer_batch_axis_names=["data", "fsdp"], + mesh_shape=inferred_shape, + ) + self.assertEqual(result, 1 * 2) + + if __name__ == "__main__": absltest.main() diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host.txt new file mode 100644 index 000000000..5d9f40c66 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host.txt @@ -0,0 +1,496 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' 
+checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 5000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 250000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 250000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 25000 +evalers['train'].input.batcher.feed_batch_size: 8 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 250000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 25000 +evalers['validation'].input.batcher.feed_batch_size: 8 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 
'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.feed_batch_size: 8 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].adam_update_transformation.fn: 'axlearn.common.optimizers.scale_update_per_param' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.default_scale: 1.0 
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.description: 'scale_by_mup_simple' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.fn: 'axlearn.common.optimizers.per_param_scale_by_path' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][0]: '.*attention/o_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][0]: '.*attention/i_proj/i_proj/qkv_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][0]: '.*feed_forward/linear1_0/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][0]: '.*feed_forward/linear1_1/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][0]: '.*feed_forward/linear2/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][0]: '.*feed_forward/wi_0_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][0]: '.*feed_forward/wi_1_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][0]: '.*feed_forward/wo_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][1]: 0.5 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.01 +learner.optimizer.args[1].update_schedule.alpha: 0.005 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 250000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 5000 +learner.optimizer.args[1].weight_decay: 0.0001 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 250000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5p-(1024|2048)' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 16 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 16 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' 
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'tpu-v6e-256' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 16 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 16 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 16 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 1536 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0][0]: 'expert' +model.decoder.emb.token_emb.param_partition_spec[0][1]: 'fsdp' +model.decoder.emb.token_emb.param_partition_spec[0][2]: 'seq' +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.StackedTransformerLayer' +model.decoder.transformer.layer.layer[0].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[0].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[0].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' 
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.round_up_to_multiples_of: 128 +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.scale: 4 +model.decoder.transformer.layer.layer[0].feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.bias: False +model.decoder.transformer.layer.layer[0].feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.bias: False +model.decoder.transformer.layer.layer[0].feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[0].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[0].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[0].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_offloaded: None 
+model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[0].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[0].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.num_kv_heads: 12 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None 
+model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.num_heads: 12 +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' 
+model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[0].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[0].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.layer[1].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[1].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][0]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][2]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][1]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][3]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][3]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][4]: None 
+model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][2]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].feed_forward.gating.eval_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.gating_logit_cap: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.klass: 'axlearn.common.mixture_of_experts.Top2Gating' +model.decoder.transformer.layer.layer[1].feed_forward.gating.mask_dtype: 'jax.numpy.int32' +model.decoder.transformer.layer.layer[1].feed_forward.gating.train_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.round_up_to_multiples_of: 128 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.scale: 4 +model.decoder.transformer.layer.layer[1].feed_forward.input_dim: 1536 +model.decoder.transformer.layer.layer[1].feed_forward.klass: 'axlearn.common.mixture_of_experts.TransformerFeedForwardMoE' +model.decoder.transformer.layer.layer[1].feed_forward.load_balance_loss_weight: 0.01 +model.decoder.transformer.layer.layer[1].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].feed_forward.num_experts: 128 +model.decoder.transformer.layer.layer[1].feed_forward.num_groups: 2 +model.decoder.transformer.layer.layer[1].feed_forward.outer_batch: -1 +model.decoder.transformer.layer.layer[1].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[1].feed_forward.router_z_loss_weight: 0.0 
+model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[1].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[1].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[1].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[1].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[1].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.num_kv_heads: 12 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' 
+model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.num_heads: 12 +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' 
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[1].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[1].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.num_layers: 2 +model.decoder.transformer.layer.remat_spec['prevent_cse']: True +model.decoder.transformer.layer.remat_spec['policy']: 'jax._src.ad_checkpoint.nothing_saveable' +model.decoder.transformer.num_layers: 6 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['lm'].z_loss_scale: 1e-06 +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True 
+save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_init.txt new file mode 100644 index 000000000..c6255c7d1 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_init.txt @@ -0,0 +1,23 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 1536], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 36, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 12, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(6144, 1536), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 36, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 12, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: normal(0, 1.0 / fan_in), shape=(1536, 128), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: normal(0, 1.0 / fan_in), shape=(128, 6144, 1536), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0) +decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: normal(0, 1.0 / fan_in), shape=(128, 1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0) +decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: normal(0, 1.0 / fan_in), shape=(128, 1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0) 
+decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: constant(1.0) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_regularizer.txt new file mode 100644 index 000000000..f218e7067 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base-single-host_regularizer.txt @@ -0,0 +1,24 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: 1 +decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: 1 +decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base.txt new file mode 100644 index 000000000..09d5f6d86 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base.txt @@ -0,0 +1,493 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 5000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 250000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 
'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 250000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 25000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 250000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 25000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' 
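The evaler and trainer inputs in this config pin global_logical_batch_size: 1024 and max_sequence_length: 8192. A quick sanity check of the implied token throughput (counting padded positions, since examples are packed to fixed length):

    # 1024 sequences/step * 8192 token positions/sequence:
    tokens_per_step = 1024 * 8192
    assert tokens_per_step == 8_388_608  # ~8.4M token positions per step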
+evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].adam_update_transformation.fn: 'axlearn.common.optimizers.scale_update_per_param' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.default_scale: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.description: 'scale_by_mup_simple' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.fn: 'axlearn.common.optimizers.per_param_scale_by_path' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][0]: '.*attention/o_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][0]: '.*attention/i_proj/i_proj/qkv_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][0]: 
'.*feed_forward/linear1_0/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][0]: '.*feed_forward/linear1_1/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][0]: '.*feed_forward/linear2/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][0]: '.*feed_forward/wi_0_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][0]: '.*feed_forward/wi_1_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][1]: 0.5 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][0]: '.*feed_forward/wo_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][1]: 0.5 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.01 +learner.optimizer.args[1].update_schedule.alpha: 0.005 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 250000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 5000 +learner.optimizer.args[1].weight_decay: 0.0001 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 250000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5p-(1024|2048)' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 16 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 16 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'tpu-v6e-256' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 16 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 16 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 
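Each mesh_rules entry above pairs an accelerator-type pattern with a ChainConfigModifier; a rule is selected when its pattern matches the running instance type, otherwise the default mesh_shape applies (with -1 meaning the axis size is inferred from the remaining device count). A minimal sketch of that matching, with select_rule as a hypothetical stand-in for AXLearn's actual lookup:

    import re

    mesh_rules = [
        ("tpu-v5p-(1024|2048)", "v5p config modifiers"),
        ("tpu-v6e-256", "v6e config modifiers"),
    ]

    def select_rule(instance_type):
        # Hypothetical helper: return the first rule whose regex matches.
        for pattern, modifiers in mesh_rules:
            if re.fullmatch(pattern, instance_type):
                return modifiers
        return None  # fall back to the default mesh_shape

    assert select_rule("tpu-v6e-256") == "v6e config modifiers"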
+mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 16 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 1536 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0][0]: 'expert' +model.decoder.emb.token_emb.param_partition_spec[0][1]: 'fsdp' +model.decoder.emb.token_emb.param_partition_spec[0][2]: 'seq' +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.StackedTransformerLayer' +model.decoder.transformer.layer.layer[0].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[0].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[0].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.round_up_to_multiples_of: 128 +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.scale: 4 +model.decoder.transformer.layer.layer[0].feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.bias: False +model.decoder.transformer.layer.layer[0].feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][1]: 'expert' 
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.bias: False +model.decoder.transformer.layer.layer[0].feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[0].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[0].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[0].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[0].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[0].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.bias: False 
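The nested param_partition_spec entries for linear1 a few lines above shard a single tensor dim over several mesh axes at once. Expressed as a jax.sharding.PartitionSpec (illustrative; AXLearn constructs these specs internally):

    from jax.sharding import PartitionSpec as P

    # linear1 weight, shape (1536, 6144): dim 0 sharded jointly over
    # ('expert', 'fsdp', 'seq'), dim 1 over 'model'.
    linear1_weight_spec = P(("expert", "fsdp", "seq"), "model")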
+model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.num_kv_heads: 12 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' 
+model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.num_heads: 12 +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' 
+model.decoder.transformer.layer.layer[0].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[0].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[0].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.layer[1].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[1].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][0]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][2]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][1]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][3]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][3]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][4]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][0]: 'data' 
+model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][2]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].feed_forward.gating.eval_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.gating_logit_cap: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.klass: 'axlearn.common.mixture_of_experts.Top2Gating' +model.decoder.transformer.layer.layer[1].feed_forward.gating.mask_dtype: 'jax.numpy.int32' +model.decoder.transformer.layer.layer[1].feed_forward.gating.train_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.round_up_to_multiples_of: 128 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.scale: 4 +model.decoder.transformer.layer.layer[1].feed_forward.input_dim: 1536 +model.decoder.transformer.layer.layer[1].feed_forward.klass: 'axlearn.common.mixture_of_experts.TransformerFeedForwardMoE' +model.decoder.transformer.layer.layer[1].feed_forward.load_balance_loss_weight: 0.01 +model.decoder.transformer.layer.layer[1].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].feed_forward.num_experts: 128 +model.decoder.transformer.layer.layer[1].feed_forward.num_groups: 2 +model.decoder.transformer.layer.layer[1].feed_forward.outer_batch: -1 +model.decoder.transformer.layer.layer[1].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[1].feed_forward.router_z_loss_weight: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[1].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[1].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[1].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_offloaded: None 
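The Top2Gating config above (num_experts: 128, num_groups: 2, train/eval capacity_factor: 2.0) implies the usual GShard/Switch-style expert capacity. A sketch of that arithmetic for the 1024 x 8192 batch, ignoring sharding details and whatever rounding AXLearn's router applies:

    batch, seq_len = 1024, 8192
    num_groups, num_experts, capacity_factor = 2, 128, 2.0

    tokens_per_group = batch * seq_len // num_groups      # 4_194_304
    expert_capacity = int(capacity_factor * tokens_per_group / num_experts)
    assert expert_capacity == 65_536  # tokens each expert can accept per group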
+model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[1].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[1].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.num_kv_heads: 12 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None 
+model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.num_heads: 12 +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' 
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[1].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[1].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.num_layers: 2 +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy']: 'jax._src.ad_checkpoint.dots_saveable' +model.decoder.transformer.num_layers: 6 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['lm'].z_loss_scale: 1e-06 +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_init.txt new file mode 100644 index 000000000..c6255c7d1 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_init.txt @@ -0,0 +1,23 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), 
shape=[32768, 1536], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 36, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 12, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(6144, 1536), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 36, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(1536, 12, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: normal(0, 1.0 / fan_in), shape=(1536, 128), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: normal(0, 1.0 / fan_in), shape=(128, 6144, 1536), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0) +decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: normal(0, 1.0 / fan_in), shape=(128, 1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0) +decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: normal(0, 1.0 / fan_in), shape=(128, 1536, 6144), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0) +decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: constant(1.0) +decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: constant(1.0) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_regularizer.txt new file mode 100644 index 000000000..f218e7067 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Base_regularizer.txt @@ -0,0 +1,24 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 
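The init fixture above makes the block structure concrete: layer0 is a dense feed-forward (linear1_0/linear1_1/linear2) and layer1 is the MoE feed-forward (gate_weight plus wi_0/wi_1/wo with a leading num_experts dim). With transformer.layer.num_layers: 2 and transformer.num_layers: 6 in the trainer config, the dense+MoE pair repeats six times, i.e. twelve transformer layers with experts in every other one. A rough parameter count from the listed shapes:

    d_model, d_hidden, num_experts = 1536, 6144, 128

    # MoE FFN per expert: wi_0 + wi_1 + wo.
    per_expert = 2 * d_model * d_hidden + d_hidden * d_model  # ~28.3M
    moe_ffn_params = num_experts * per_expert                 # ~3.62B per MoE layer
    total_transformer_layers = 2 * 6                          # 12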
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: 1 +decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: 1 +decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: 1 +decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: 1 +decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: 1 +decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: 1 +decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large.txt new file mode 100644 index 000000000..0c716304a --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large.txt @@ -0,0 +1,508 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 5000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 250000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 250000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 25000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 
'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 250000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 25000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' 
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].adam_update_transformation.fn: 'axlearn.common.optimizers.scale_update_per_param' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.default_scale: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.description: 'scale_by_mup_simple' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.fn: 'axlearn.common.optimizers.per_param_scale_by_path' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][0]: '.*attention/o_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][1]: 0.375 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][0]: '.*attention/i_proj/i_proj/qkv_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][1]: 0.375 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][0]: '.*feed_forward/linear1_0/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][1]: 0.375 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][0]: '.*feed_forward/linear1_1/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][1]: 0.375 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][0]: '.*feed_forward/linear2/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][1]: 0.375 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][0]: '.*feed_forward/wi_0_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][1]: 0.375 
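The scale_by_mup_simple values differ between configs: 0.5 for envy-Switch-Base (model dim 1536, earlier in this patch) versus 0.375 here for envy-Switch-Large (model dim 2048, set later in this file). The two are consistent with a per-parameter scale of base_dim / model_dim for a hypothetical base width of 768, an assumption inferred from the numbers rather than stated in the patch:

    base_dim = 768  # inferred, not stated in the patch
    assert base_dim / 1536 == 0.5    # envy-Switch-Base
    assert base_dim / 2048 == 0.375  # envy-Switch-Large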
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][0]: '.*feed_forward/wi_1_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][1]: 0.375 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][0]: '.*feed_forward/wo_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][1]: 0.375 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.01 +learner.optimizer.args[1].update_schedule.alpha: 0.005 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 250000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 5000 +learner.optimizer.args[1].weight_decay: 0.0001 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 250000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5p-(1024|2048)' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 16 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 16 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'tpu-v6e-256-4' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 16 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 16 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 16 
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 16
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
+mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4
+mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
+mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
+mesh_shape[0]: 1
+mesh_shape[1]: 1
+mesh_shape[2]: 16
+mesh_shape[3]: -1
+mesh_shape[4]: 1
+mesh_shape[5]: 1
+model.batch_axis_names: None
+model.decoder.attention_mask: None
+model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
+model.decoder.dim: 2048
+model.decoder.dropout_rate: 0.0
+model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
+model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+model.decoder.emb.token_emb.param_partition_spec[0][0]: 'expert'
+model.decoder.emb.token_emb.param_partition_spec[0][1]: 'fsdp'
+model.decoder.emb.token_emb.param_partition_spec[0][2]: 'seq'
+model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
+model.decoder.eos_token_id: 1
+model.decoder.klass: 'axlearn.common.decoder.Decoder'
+model.decoder.logits_partition_spec[0][0]: 'data'
+model.decoder.logits_partition_spec[0][1]: 'expert'
+model.decoder.logits_partition_spec[0][2]: 'fsdp'
+model.decoder.logits_partition_spec[1]: 'seq'
+model.decoder.logits_partition_spec[2]: 'model'
+model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.output_norm.eps: 1e-05
+model.decoder.output_norm.forward_dtype: None
+model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.pad_token_id: 0
+model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
+model.decoder.transformer.layer.klass: 'axlearn.common.attention.StackedTransformerLayer'
+model.decoder.transformer.layer.layer[0].feed_forward.activation[0]: 'nn.silu'
+model.decoder.transformer.layer.layer[0].feed_forward.activation[1]: 'linear'
+model.decoder.transformer.layer.layer[0].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.round_up_to_multiples_of: 128
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.scale: 4
+model.decoder.transformer.layer.layer[0].feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.bias: False
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.bias: False
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[0]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][0]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][2]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.norm.eps: 1e-05
+model.decoder.transformer.layer.layer[0].feed_forward.norm.forward_dtype: None
+model.decoder.transformer.layer.layer[0].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.layer[0].feed_forward.residual_weight: 1.0
+model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.layer[0].feed_forward.structure: 'hybridnorm'
+model.decoder.transformer.layer.layer[0].klass: 'axlearn.common.attention.TransformerLayer'
+model.decoder.transformer.layer.layer[0].remat_spec['prevent_cse']: False
+model.decoder.transformer.layer.layer[0].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_offloaded: None
+model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_src: 'device'
+model.decoder.transformer.layer.layer[0].self_attention.attention.causal: True
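A note on the layer[0].remat_spec entries above: names_which_can_be_saved is the regex '.*([qkvo]_proj|context)' and names_which_can_be_offloaded is None, so during rematerialization this policy keeps only the named attention projections and the attention context, while offload_src/offload_dst describe the device-to-pinned-host path an offloading variant would use. AXLearn's _save_and_offload_only_these_names_regex is an internal helper; a rough standalone analogue using only public JAX APIs is sketched below, where the toy attn_block and its shapes are illustrative rather than the AXLearn layer (newer JAX also ships jax.checkpoint_policies.save_and_offload_only_these_names, which takes exact names rather than regexes):

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint, checkpoint_name

    def attn_block(x, wq, wo):
        # Tag intermediates with the names the remat policy matches on.
        q = checkpoint_name(x @ wq, "q_proj")
        context = checkpoint_name(jax.nn.softmax(q @ q.T) @ x, "context")
        return context @ wo

    # Save (rather than recompute) only the named activations during backprop.
    policy = jax.checkpoint_policies.save_only_these_names("q_proj", "context")
    attn_remat = checkpoint(attn_block, policy=policy, prevent_cse=False)

    x = jnp.ones((8, 16)); wq = jnp.ones((16, 16)); wo = jnp.ones((16, 16))
    grads = jax.grad(lambda *a: attn_remat(*a).sum(), argnums=(1, 2))(x, wq, wo)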
+model.decoder.transformer.layer.layer[0].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.num_kv_heads: 16 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' 
+model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.num_heads: 16 +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' 
+model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[0].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[0].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.layer[1].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[1].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][0]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][2]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][1]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][3]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][3]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][4]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][2]: None 
+model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][2]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].feed_forward.gating.eval_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.gating_logit_cap: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.klass: 'axlearn.common.mixture_of_experts.Top2Gating' +model.decoder.transformer.layer.layer[1].feed_forward.gating.mask_dtype: 'jax.numpy.int32' +model.decoder.transformer.layer.layer[1].feed_forward.gating.train_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.round_up_to_multiples_of: 128 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.scale: 4 +model.decoder.transformer.layer.layer[1].feed_forward.input_dim: 2048 +model.decoder.transformer.layer.layer[1].feed_forward.klass: 'axlearn.common.mixture_of_experts.TransformerFeedForwardMoE' +model.decoder.transformer.layer.layer[1].feed_forward.load_balance_loss_weight: 0.01 +model.decoder.transformer.layer.layer[1].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].feed_forward.num_experts: 128 +model.decoder.transformer.layer.layer[1].feed_forward.num_groups: 2 +model.decoder.transformer.layer.layer[1].feed_forward.outer_batch: -1 +model.decoder.transformer.layer.layer[1].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[1].feed_forward.router_z_loss_weight: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[1].klass: 'axlearn.common.attention.TransformerLayer' 
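For orientation in the TransformerFeedForwardMoE block that ends here: the single-letter einsum dims in dim_to_mesh_axis_map appear to follow the Praxis/Mesh-TensorFlow convention (o = outer batch, g = group, s = tokens per group, e = expert, c = per-expert capacity, m = model dim, h = hidden dim), so 'emh' is the per-expert up-projection weight and 'oegcm' a dispatched activation. The config stores capacity factors rather than an absolute per-expert capacity; a minimal sketch of the usual derivation follows, assuming tokens are split evenly over the num_groups=2 groups and using this config's batch and sequence sizes (the exact AXLearn rounding may differ):

    import math

    def expert_capacity(tokens_per_group: int, num_experts: int, capacity_factor: float) -> int:
        # Tokens routed to an expert beyond this bound are dropped; the
        # load_balance loss (weight 0.01 above) discourages such imbalance.
        return math.ceil(tokens_per_group * capacity_factor / num_experts)

    tokens_per_group = 1024 * 8192 // 2  # global_logical_batch_size x max_sequence_length / num_groups
    print(expert_capacity(tokens_per_group, num_experts=128, capacity_factor=2.0))  # 65536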
+model.decoder.transformer.layer.layer[1].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[1].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[1].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[1].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.num_kv_heads: 16 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None 
+model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.num_heads: 16 +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' 
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[2]: None
+model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.eps: 1e-05
+model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.forward_dtype: None
+model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.layer[1].self_attention.attention.tpu_block_size: 512
+model.decoder.transformer.layer.layer[1].self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.layer[1].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+model.decoder.transformer.layer.layer[1].self_attention.norm.eps: 1e-05
+model.decoder.transformer.layer.layer[1].self_attention.norm.forward_dtype: None
+model.decoder.transformer.layer.layer[1].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.layer[1].self_attention.structure: 'prenorm'
+model.decoder.transformer.layer.num_layers: 2
+model.decoder.transformer.layer.remat_spec['prevent_cse']: False
+model.decoder.transformer.layer.remat_spec['policy']: 'jax._src.ad_checkpoint.dots_saveable'
+model.decoder.transformer.num_layers: 12
+model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
+model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
+model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
+model.decoder.vocab_size: 32768
+model.dtype: 'jax.numpy.float32'
+model.klass: 'axlearn.common.causal_lm.Model'
+model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics'
+model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics'
+model.metrics.metrics['lm'].z_loss_scale: 1e-06
+model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics'
+model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
+model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+name: 'gpt_trainer'
+prune_empty_state_updates: True
+save_input_iterator: False
+start_trace_process_indices[0]: 0
+summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+summary_writer.max_queue: 1000
+summary_writer.write_every_n_steps: 100
+train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_init.txt
new file mode 100644
index 000000000..94e8354cd
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_init.txt
@@ -0,0 +1,23 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 16, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 16, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: normal(0, 1.0 / fan_in), shape=(2048, 128), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: normal(0, 1.0 / fan_in), shape=(128, 8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: normal(0, 1.0 / fan_in), shape=(128, 2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: normal(0, 1.0 / fan_in), shape=(128, 2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: constant(1.0)
+decoder/output_norm/scale: constant(1.0)
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_regularizer.txt
new file mode 100644
index 000000000..f218e7067
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-Large_regularizer.txt
@@ -0,0 +1,24 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL.txt
new file mode 100644
index 000000000..66a86dcae
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL.txt
@@ -0,0 +1,469 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 5000
+checkpointer.keep_last_n: 3
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
+checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
+checkpointer.save_policy.max_step: 250000
+checkpointer.save_policy.min_step: 1
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['train'].eval_policy.max_step: 250000
+evalers['train'].eval_policy.min_step: 1
+evalers['train'].eval_policy.n: 25000
+evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['train'].input.batcher.prefetch_buffer_size: -1
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['train'].input.is_training: False
+evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['train'].input.source.is_training: False
+evalers['train'].input.source.max_sequence_length: 8192
+evalers['train'].input.source.replace_newlines_with: '\n'
+evalers['train'].input.source.split: 'train[:8192]'
+evalers['train'].input.source.train_shuffle_buffer_size: 16384
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['train'].metric_calculator.model_method: 'forward'
+evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['train'].summary_writer.write_every_n_steps: 1
+evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['validation'].eval_policy.max_step: 250000
+evalers['validation'].eval_policy.min_step: 1
+evalers['validation'].eval_policy.n: 25000
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['validation'].input.batcher.prefetch_buffer_size: -1
+evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['validation'].input.is_training: False
+evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['validation'].input.source.is_training: False
+evalers['validation'].input.source.max_sequence_length: 8192
+evalers['validation'].input.source.replace_newlines_with: '\n'
+evalers['validation'].input.source.split: 'validation'
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384
+evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['validation'].metric_calculator.model_method: 'forward'
+evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['validation'].summary_writer.write_every_n_steps: 1
+input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+input.batcher.prefetch_buffer_size: -1
+input.input_dispatcher.global_logical_batch_size: 1024
+input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq'
+input.is_training: True
+input.klass: 'axlearn.common.input_tf_data.Input'
+input.processor.fn: 'axlearn.common.input_tf_data.identity'
+input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
+input.source.data_mixture_components[0]['weight']: 1.0
+input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
+input.source.data_mixture_components[0]['split']: 'train'
+input.source.data_mixture_components[0]['info']: ''
+input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
+input.source.max_sequence_length: 8192
+input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
+input.source.preprocessor.max_padding_fraction: 0.5
+input.source.preprocessor.shuffle_buffer_size: 8192
+input.source.preprocessor.window_size: 128
+input.source.replace_newlines_with: ''
+input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+klass: 'axlearn.common.trainer.SpmdTrainer'
+learner.ema.fn: 'axlearn.common.optimizers.param_ema'
+learner.enable_per_variable_summaries: False
+learner.klass: 'axlearn.common.learner.Learner'
+learner.optimizer.args[0].eps: 1e-08
+learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
+learner.optimizer.args[0].max_norm: 1
+learner.optimizer.args[1].adam_update_transformation.fn: 'axlearn.common.optimizers.scale_update_per_param'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.default_scale: 1.0
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.description: 'scale_by_mup_simple'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.fn: 'axlearn.common.optimizers.per_param_scale_by_path'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][0]: '.*attention/o_proj/weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][1]: 0.09375
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][0]: '.*attention/i_proj/i_proj/qkv_proj/weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][1]: 0.09375
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][0]: '.*feed_forward/linear1_0/weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][1]: 0.09375
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][0]: '.*feed_forward/linear1_1/weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][1]: 0.09375
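One cross-file observation that may help when diffing these golden configs: the 'scale_by_mup_simple' per-parameter learning-rate scales (the scale_by_path pairs above, with the remaining entries continuing below) are 0.09375 for Switch-XXL here versus 0.375 in the Switch-Large config earlier. Both are consistent with a muP-style scale of base_dim / model_dim for an assumed base width of 768; the actual base used by envy.py is not visible in this diff, so the constant in this sketch is inferred rather than authoritative:

    BASE_DIM = 768  # hypothetical muP base width, inferred from the golden values

    def mup_simple_scale(model_dim: int) -> float:
        return BASE_DIM / model_dim

    assert mup_simple_scale(2048) == 0.375    # Switch-Large (model.decoder.dim: 2048)
    assert mup_simple_scale(8192) == 0.09375  # Switch-XXL (model.decoder.dim: 8192)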
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][0]: '.*feed_forward/linear2/weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][1]: 0.09375
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][0]: '.*feed_forward/wi_0_weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][1]: 0.09375
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][0]: '.*feed_forward/wi_1_weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][1]: 0.09375
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][0]: '.*feed_forward/wo_weight'
+learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][1]: 0.09375
+learner.optimizer.args[1].b1: 0.9
+learner.optimizer.args[1].b2: 0.95
+learner.optimizer.args[1].eps: 1e-08
+learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
+learner.optimizer.args[1].learning_rate: 0.01
+learner.optimizer.args[1].update_schedule.alpha: 0.005
+learner.optimizer.args[1].update_schedule.begin_value: 0.0
+learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
+learner.optimizer.args[1].update_schedule.max_step: 250000
+learner.optimizer.args[1].update_schedule.peak_lr: 1.0
+learner.optimizer.args[1].update_schedule.warmup_steps: 5000
+learner.optimizer.args[1].weight_decay: 0.0001
+learner.optimizer.fn: 'axlearn.common.optimizers.chain'
+max_step: 250000
+mesh_axis_names[0]: 'pipeline'
+mesh_axis_names[1]: 'data'
+mesh_axis_names[2]: 'expert'
+mesh_axis_names[3]: 'fsdp'
+mesh_axis_names[4]: 'seq'
+mesh_axis_names[5]: 'model'
+mesh_shape[0]: 1
+mesh_shape[1]: 1
+mesh_shape[2]: 16
+mesh_shape[3]: -1
+mesh_shape[4]: 1
+mesh_shape[5]: 8
+model.batch_axis_names: None
+model.decoder.attention_mask: None
+model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
+model.decoder.dim: 8192
+model.decoder.dropout_rate: 0.0
+model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
+model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+model.decoder.emb.token_emb.param_partition_spec[0][0]: 'expert'
+model.decoder.emb.token_emb.param_partition_spec[0][1]: 'fsdp'
+model.decoder.emb.token_emb.param_partition_spec[0][2]: 'seq'
+model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
+model.decoder.eos_token_id: 1
+model.decoder.klass: 'axlearn.common.decoder.Decoder'
+model.decoder.logits_partition_spec[0][0]: 'data'
+model.decoder.logits_partition_spec[0][1]: 'expert'
+model.decoder.logits_partition_spec[0][2]: 'fsdp'
+model.decoder.logits_partition_spec[1]: 'seq'
+model.decoder.logits_partition_spec[2]: 'model'
+model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.output_norm.eps: 1e-05
+model.decoder.output_norm.forward_dtype: None
+model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.pad_token_id: 0
+model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
+model.decoder.transformer.layer.klass: 'axlearn.common.attention.StackedTransformerLayer'
+model.decoder.transformer.layer.layer[0].feed_forward.activation[0]: 'nn.silu'
+model.decoder.transformer.layer.layer[0].feed_forward.activation[1]: 'linear'
+model.decoder.transformer.layer.layer[0].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.round_up_to_multiples_of: 128
+model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.scale: 2.5
+model.decoder.transformer.layer.layer[0].feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.bias: False
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.bias: False
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[0]: 'model'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][0]: 'expert'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][2]: 'seq'
+model.decoder.transformer.layer.layer[0].feed_forward.norm.eps: 1e-05
+model.decoder.transformer.layer.layer[0].feed_forward.norm.forward_dtype: None
+model.decoder.transformer.layer.layer[0].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.layer[0].feed_forward.residual_weight: 1.0
+model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[0].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[0].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[0].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[0].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[0].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' 
+model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None 
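The mha_dim_to_partition_spec entries above flatten nested tuples into indexed keys; read together, the 'btnh' (batch, time, num_heads, per_head) layout is PartitionSpec(('data', 'expert', 'fsdp'), None, ('seq', 'model'), None) over the trainer's six mesh axes. A small self-contained sketch of what such a spec means in plain jax.sharding terms, using toy mesh sizes rather than the v5p/v6e shapes from the mesh rules:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Assumes >= 4 devices (e.g., XLA_FLAGS=--xla_force_host_platform_device_count=4).
    devices = np.array(jax.devices()[:4]).reshape(1, 1, 2, 2, 1, 1)
    mesh = Mesh(devices, ("pipeline", "data", "expert", "fsdp", "seq", "model"))

    # Batch is sharded over (data, expert, fsdp) and heads over (seq, model);
    # the time and per-head dims are replicated.
    btnh = PartitionSpec(("data", "expert", "fsdp"), None, ("seq", "model"), None)
    x = jax.device_put(np.zeros((8, 16, 4, 128), np.float32), NamedSharding(mesh, btnh))
    print(x.sharding)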
+model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[0].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[0].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.layer[1].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[1].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][0]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][2]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][1]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][3]: 'model' 
+model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][3]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][4]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][2]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].feed_forward.gating.eval_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.gating_logit_cap: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.klass: 'axlearn.common.mixture_of_experts.Top2Gating' +model.decoder.transformer.layer.layer[1].feed_forward.gating.mask_dtype: 'jax.numpy.int32' +model.decoder.transformer.layer.layer[1].feed_forward.gating.train_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.round_up_to_multiples_of: 128 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.scale: 2.5 +model.decoder.transformer.layer.layer[1].feed_forward.input_dim: 8192 +model.decoder.transformer.layer.layer[1].feed_forward.klass: 'axlearn.common.mixture_of_experts.TransformerFeedForwardMoE' +model.decoder.transformer.layer.layer[1].feed_forward.load_balance_loss_weight: 0.01 +model.decoder.transformer.layer.layer[1].feed_forward.norm.eps: 1e-05 
+model.decoder.transformer.layer.layer[1].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].feed_forward.num_experts: 64 +model.decoder.transformer.layer.layer[1].feed_forward.num_groups: 2 +model.decoder.transformer.layer.layer[1].feed_forward.outer_batch: -1 +model.decoder.transformer.layer.layer[1].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[1].feed_forward.router_z_loss_weight: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].feed_forward.structure: 'hybridnorm' +model.decoder.transformer.layer.layer[1].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[1].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[1].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[1].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[1].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' 
+model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None 
+model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.layer[1].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[1].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.num_layers: 2 +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy']: 'jax._src.ad_checkpoint.dots_saveable' +model.decoder.transformer.num_layers: 12 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 
'axlearn.common.causal_lm.CrossEntropyLossMetrics'
+model.metrics.metrics['lm'].z_loss_scale: 1e-06
+model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics'
+model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
+model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+name: 'gpt_trainer'
+prune_empty_state_updates: True
+save_input_iterator: False
+start_trace_process_indices[0]: 0
+summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+summary_writer.max_queue: 1000
+summary_writer.write_every_n_steps: 100
+train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_init.txt
new file mode 100644
index 000000000..f9747bac8
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_init.txt
@@ -0,0 +1,23 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 20480), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 20480), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(20480, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: normal(0, 1.0 / fan_in), shape=(8192, 64), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: normal(0, 1.0 / fan_in), shape=(64, 20480, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: normal(0, 1.0 / fan_in), shape=(64, 8192, 20480), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: normal(0, 1.0 / fan_in), shape=(64, 8192, 20480), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: constant(1.0)
+decoder/output_norm/scale: constant(1.0)
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_regularizer.txt
new file mode 100644
index 000000000..f218e7067
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-Switch-XXL_regularizer.txt
@@ -0,0 +1,24 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/postnorm/scale: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/prenorm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/postnorm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/prenorm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test.txt
new file mode 100644
index 000000000..67896e3d1
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test.txt
@@ -0,0 +1,401 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 3000
+checkpointer.keep_last_n: 3
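Two quick arithmetic cross-checks against the Switch-XXL golden files above, written as a plain-Python sketch (no axlearn imports; `round_up` is a hypothetical stand-in for the configured `scale_fn` rounding):

import math

def round_up(value: float, multiple: int) -> int:
    # Mirrors hidden_dim.round_up_to_multiples_of as configured above;
    # the exact axlearn implementation may differ in details.
    return int(math.ceil(value / multiple) * multiple)

# MoE hidden dim: input_dim 8192 with scale 2.5, rounded up to a
# multiple of 128, gives 20480 -- matching the (8192, 20480) linear1/wi
# shapes in envy-Switch-XXL_init.txt.
assert round_up(8192 * 2.5, 128) == 20480

# FusedGroupedQKVLinear packs 64 query heads plus 8 key and 8 value
# heads into one weight, hence the (8192, 80, 128) qkv_proj shape:
assert 64 + 2 * 8 == 80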
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 25000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 16 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 64 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 25000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 16 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 64 +evalers['validation'].input.source.replace_newlines_with: '\n' 
+evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 16 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 64 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].adam_update_transformation.fn: 'axlearn.common.optimizers.scale_update_per_param' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.default_scale: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.description: 'scale_by_mup_simple' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.fn: 'axlearn.common.optimizers.per_param_scale_by_path' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][0]: 
'.*attention/o_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[0][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][0]: '.*attention/i_proj/i_proj/qkv_proj/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[1][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][0]: '.*feed_forward/linear1_0/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[2][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][0]: '.*feed_forward/linear1_1/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[3][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][0]: '.*feed_forward/linear2/weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[4][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][0]: '.*feed_forward/wi_0_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[5][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][0]: '.*feed_forward/wi_1_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[6][1]: 1.0 +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][0]: '.*feed_forward/wo_weight' +learner.optimizer.args[1].adam_update_transformation.per_param_scale.scale_by_path[7][1]: 1.0 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.01 +learner.optimizer.args[1].update_schedule.alpha: 0.005 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.000316 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 
'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0][0]: 'expert' +model.decoder.emb.token_emb.param_partition_spec[0][1]: 'fsdp' +model.decoder.emb.token_emb.param_partition_spec[0][2]: 'seq' +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.StackedTransformerLayer' +model.decoder.transformer.layer.layer[0].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[0].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[0].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.layer[0].feed_forward.hidden_dim.scale: 2.6666666666666665 +model.decoder.transformer.layer.layer[0].feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.bias: False +model.decoder.transformer.layer.layer[0].feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.bias: False +model.decoder.transformer.layer.layer[0].feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.output_partition_spec[2]: 'model' 
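Stepping out of the listing briefly: the `cosine_with_linear_warmup` entries above (begin_value 0.0, peak_lr 1.0, warmup_steps 2000, max_step 3000, alpha 0.005) define a schedule multiplier that `adamw_decoupled_optimizer` scales by its base learning_rate of 0.01. A hedged sketch of the schedule's shape, not the exact axlearn implementation:

import math

def cosine_with_linear_warmup(step, peak_lr=1.0, warmup_steps=2000,
                              max_step=3000, alpha=0.005, begin_value=0.0):
    # Linear ramp from begin_value to peak_lr over warmup_steps, then a
    # cosine decay that bottoms out at alpha * peak_lr by max_step.
    if step < warmup_steps:
        return begin_value + (peak_lr - begin_value) * step / warmup_steps
    frac = min(1.0, (step - warmup_steps) / max(1, max_step - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
    return peak_lr * (alpha + (1.0 - alpha) * cosine)

# The effective learning rate at a step is then roughly
# 0.01 * cosine_with_linear_warmup(step).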
+model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.layer[0].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.layer[0].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[0].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[0].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[0].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[0].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[0].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[0].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 
+model.decoder.transformer.layer.layer[0].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.layer[0].self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[0].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[0].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[0].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[0].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[0].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[0].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[0].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.layer[1].feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.layer[1].feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][0]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['me'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][0]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['emh'][2]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][0]: 'expert' 
+model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][1]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ehm'][2]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsm'][3]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][3]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogsec'][4]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegcm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][1]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][2]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['ogecm'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][0]: 'data' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][1]: 'expert' +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][2]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][3]: None +model.decoder.transformer.layer.layer[1].feed_forward.dim_to_mesh_axis_map['oegch'][4]: 'model' +model.decoder.transformer.layer.layer[1].feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].feed_forward.gating.eval_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.gating_logit_cap: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.gating.klass: 'axlearn.common.mixture_of_experts.Top2Gating' +model.decoder.transformer.layer.layer[1].feed_forward.gating.mask_dtype: 'jax.numpy.int32' 
+model.decoder.transformer.layer.layer[1].feed_forward.gating.train_capacity_factor: 2.0 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.layer[1].feed_forward.hidden_dim.scale: 2.6666666666666665 +model.decoder.transformer.layer.layer[1].feed_forward.input_dim: 8 +model.decoder.transformer.layer.layer[1].feed_forward.klass: 'axlearn.common.mixture_of_experts.TransformerFeedForwardMoE' +model.decoder.transformer.layer.layer[1].feed_forward.load_balance_loss_weight: 0.01 +model.decoder.transformer.layer.layer[1].feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].feed_forward.num_experts: 8 +model.decoder.transformer.layer.layer[1].feed_forward.num_groups: 2 +model.decoder.transformer.layer.layer[1].feed_forward.outer_batch: 1 +model.decoder.transformer.layer.layer[1].feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.layer[1].feed_forward.router_z_loss_weight: 0.0 +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.layer[1].klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.layer[1].remat_spec['prevent_cse']: False +model.decoder.transformer.layer.layer[1].remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.layer[1].remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.layer[1].remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.layer[1].self_attention.attention.causal: True +model.decoder.transformer.layer.layer[1].self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None 
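The `Top2Gating` entries above route each token to its two highest-scoring experts, subject to a per-expert capacity derived from the capacity factor; tokens beyond capacity fall back to the residual path in common implementations. A rough sketch of the capacity bookkeeping under those conventions (the exact axlearn formula may round or clip differently):

def expert_capacity(tokens_per_group: int, num_experts: int,
                    capacity_factor: float) -> int:
    # Perfectly balanced routing sends tokens_per_group / num_experts
    # tokens to each expert; the factor leaves headroom for imbalance.
    return max(1, int(tokens_per_group * capacity_factor / num_experts))

# envy-test values above: num_experts 8, train_capacity_factor 2.0.
# A group of 64 tokens then yields 16 slots per expert.
assert expert_capacity(64, num_experts=8, capacity_factor=2.0) == 16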
+model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.layer[1].self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.key_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.layer[1].self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.layer[1].self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.attention.query_scale.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.layer[1].self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.layer[1].self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.layer[1].self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.layer[1].self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.layer[1].self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.layer[1].self_attention.structure: 'prenorm' +model.decoder.transformer.layer.num_layers: 2 +model.decoder.transformer.num_layers: 2 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' 
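One reading of the layer counts just above: the `RepeatedTransformerLayer` repeats a two-layer `StackedTransformerLayer` (dense feed-forward at layer[0], MoE at layer[1]), so `transformer.num_layers` counts repeats of the pair rather than individual blocks. A tiny illustrative sketch under that assumption:

def layer_types(num_repeats: int) -> list:
    # Each repeat contributes one dense block followed by one MoE block.
    return ["dense", "moe"] * num_repeats

# envy-test: transformer.num_layers = 2 -> 4 alternating blocks.
assert layer_types(2) == ["dense", "moe", "dense", "moe"]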
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_init.txt
new file mode 100644
index 000000000..7c5711907
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_init.txt
@@ -0,0 +1,21 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: normal(0, 1.0 / fan_in), shape=(8, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: normal(0, 1.0 / fan_in), shape=(8, 32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: normal(0, 1.0 / fan_in), shape=(8, 8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: normal(0, 1.0 / fan_in), shape=(8, 8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=0)
+decoder/transformer/repeat/layer/layer1/feed_forward/norm/scale: constant(1.0)
+decoder/output_norm/scale: constant(1.0)
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_regularizer.txt
new file mode 100644
index 000000000..a1db1529b
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/envy-test_regularizer.txt
@@ -0,0 +1,22 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/layer0/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_key/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/attention/scale_query/norm/scale: 1
+decoder/transformer/repeat/layer/layer0/self_attention/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/gate_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_0_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wi_1_weight: 1
+decoder/transformer/repeat/layer/layer1/feed_forward/wo_weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_key/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/self_attention/attention/scale_query/norm/scale: 1
+decoder/transformer/repeat/layer/layer1/self_attention/norm/scale: 1
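Note: the init golden above records each weight as normal(0, 1.0 / fan_in) with explicit FanAxes, i.e. variance 1/fan_in and std 1/sqrt(fan_in). For instance, the qkv_proj weight of shape (8, 8, 2) with in_axis=0 has fan_in = 8 and std = 1/sqrt(8) ≈ 0.354. A generic sketch of that arithmetic (the helper below is hypothetical, not axlearn.common.param_init.WeightInitializer):

    import math

    # std for a normal(0, scale / fan_in) spec; fan_in multiplies the in_axis dims.
    def fan_in_std(shape, in_axes, scale=1.0):
        fan_in = math.prod(shape[a] for a in in_axes)
        return math.sqrt(scale / fan_in)

    # qkv_proj weight above: shape=(8, 8, 2), FanAxes(in_axis=0) -> fan_in = 8.
    assert abs(fan_in_std((8, 8, 2), (0,)) - 1 / math.sqrt(8)) < 1e-12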
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-fp8.txt
index 3e8f4e287..e3ecabac0 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
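Note: from here on, every fuji-70B golden collapses the inline remat policy configs into a single function reference such as 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'. Presumably fuji.py now hosts module-level policies roughly as below; this is a sketch built only from the helper names and arguments visible in the deleted lines, and the exact upstream definitions may differ:

    from axlearn.common.utils import (
        offload_dots_saveable,
        save_and_offload_only_these_names_regex,
    )

    # Stands in for the deleted inline offload_dots_saveable config: offload
    # dot activations from device HBM to pinned host memory.
    offload_dots_saveable_policy = offload_dots_saveable(
        offload_src="device", offload_dst="pinned_host"
    )

    # Stands in for the deleted save_and_offload_only_these_names_regex config:
    # offload only the attention projections and context, save nothing else.
    offload_attention_proj_policy = save_and_offload_only_these_names_regex(
        names_which_can_be_saved=None,
        names_which_can_be_offloaded=".*([qkvo]_proj|context)",
        offload_src="device",
        offload_dst="pinned_host",
    )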
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
index 30dff4e16..acce7c015 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-fp8.txt
index e38b951d2..1a55d4042 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
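Note: the regex '.*([qkvo]_proj|context)' in the deleted lines selects activations by their checkpoint names: the q/k/v/o projections plus the attention context. In raw JAX the same mechanism looks roughly like the sketch below (illustrative, using jax.ad_checkpoint directly rather than AXLearn's wrappers):

    import jax
    from jax.ad_checkpoint import checkpoint_name

    def attention_block(x, wq, wk, wv, wo):
        # Tag the tensors that a name-based offload policy may match.
        q = checkpoint_name(x @ wq, "q_proj")
        k = checkpoint_name(x @ wk, "k_proj")
        v = checkpoint_name(x @ wv, "v_proj")
        ctx = checkpoint_name(jax.nn.softmax(q @ k.T) @ v, "context")
        return checkpoint_name(ctx @ wo, "o_proj")

    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=["q_proj", "k_proj", "v_proj", "o_proj", "context"],
        offload_src="device",
        offload_dst="pinned_host",
    )
    remat_block = jax.checkpoint(attention_block, policy=policy)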
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
index 600ce8769..80415b114 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-fp8.txt
index ba0ffc057..e8ddc66fd 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
index bd1629eb8..985ee4fea 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
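Note: every one of these TPU v6e mesh rules chains a grad_acc_steps: 4 modifier after the remat and flash-attention tweaks: the per-step batch is split into 4 microbatches whose gradients are combined before a single optimizer update. A generic sketch of that arithmetic (illustrative only; AXLearn applies this through a trainer config modifier, not an explicit loop like this):

    import jax
    import jax.numpy as jnp

    def accumulated_grads(loss_fn, params, batch, steps=4):
        # Split the leading batch dim into `steps` microbatches.
        micro = jax.tree.map(lambda x: x.reshape(steps, -1, *x.shape[1:]), batch)
        grads = jax.tree.map(jnp.zeros_like, params)
        for i in range(steps):
            g = jax.grad(loss_fn)(params, jax.tree.map(lambda x: x[i], micro))
            grads = jax.tree.map(jnp.add, grads, g)
        return jax.tree.map(lambda g: g / steps, grads)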
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-fp8.txt
index 02afd1527..7a656dda6 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
index 7c81b2395..c3e8440e5 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-fp8.txt
index 782439c3e..850d0b9ff 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
index 56965d7e6..d2f92d165 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-fp8.txt
index f060419d0..3852969d2 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
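Note: each mesh rule pairs an instance-type regex with a ChainConfigModifier that applies its config_modifiers in order: mesh shape, then remat policies, then the V6e flash tweak, then gradient accumulation. A minimal stand-in for that dispatch, using hypothetical helpers rather than the axlearn.common.trainer_config_modifier API:

    import re

    def chain(*modifiers):
        def apply(cfg):
            for modify in modifiers:
                cfg = modify(cfg)  # each modifier edits and returns the config
            return cfg
        return apply

    # Hypothetical rule table; pattern and values are placeholders.
    mesh_rules = [
        (r"tpu-v6e-.*", chain(
            lambda c: {**c, "tpu_block_size": 1024},
            lambda c: {**c, "grad_acc_steps": 4},
        )),
    ]

    def apply_mesh_rules(cfg, instance_type):
        for pattern, modifier in mesh_rules:
            if re.fullmatch(pattern, instance_type):
                return modifier(cfg)
        return cfg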
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' @@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash-fp8.txt index 5d3a36d70..6992e10d5 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash-fp8.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash-fp8.txt @@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[1][0]: 'tpu-v5p-.*' mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ 
-165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' @@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt index 5cd5b2d0b..11717ee53 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt @@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[0][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[1][0]: 'tpu-v5p-.*' mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' @@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-fp8.txt
index b95b5db75..d9105f44f 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-fp8.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt
index 2359fd678..2f0340ab8 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
index 2d2cd4237..6b6903c4b 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
@@ -135,9 +135,7 @@ mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[1][0]: 'tpu-v5p-.*'
 mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -165,11 +163,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[2][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
@@ -183,11 +177,7 @@ mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[3][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[3][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[3][1].config_modifiers[3].grad_acc_steps: 4
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8-single-host.txt
index bcf68ba02..8f70fd2af 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8.txt
index 78b32d5b5..f0a186582 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
index a9ae5696a..4c96449b2 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
index 36ef37005..7b4432307 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8-single-host.txt
index f1ec1e779..a9754239b 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8.txt
index c5315baeb..669d632ee 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
index 888241cf1..48d33cab3 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
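Note: every v6e rule in these files has the same shape: a ChainConfigModifier whose config_modifiers chain a MeshShapeModifier, the RematSpecModifier above, a V6eFlashConfigModifier with tpu_block_size 1024, and (in the gradient-accumulation variants) a GradientAccumulationModifier with grad_acc_steps 4. A rough sketch of one such rule follows; the device regex and mesh axes are illustrative placeholders only, not values from this patch.

    # Hedged sketch of a chained v6e mesh rule; regex and axes are placeholders.
    from axlearn.common.trainer_config_modifier import (
        ChainConfigModifier,
        GradientAccumulationModifier,
        MeshShapeModifier,
    )
    from axlearn.common.utils import mesh_shape_from_axes
    from axlearn.experiments.trainer_config_utils import V6eFlashConfigModifier

    v6e_rule = (
        "tpu-v6e-.*",  # placeholder regex, in the style of 'tpu-v5p-.*' above
        ChainConfigModifier.default_config().set(
            config_modifiers=[
                MeshShapeModifier.default_config().set(
                    mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)  # illustrative
                ),
                remat_modifier,  # the RematSpecModifier sketched earlier
                V6eFlashConfigModifier.default_config().set(tpu_block_size=1024),
                GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
            ],
        ),
    )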
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
index 93d656948..2a85cace6 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8-single-host.txt
index f140e4897..1e66e6f74 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8.txt
index b22f0133b..ec12438a8 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
index 80daccb91..c185e8ad1 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
index a120b9358..493ee2c4e 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
 mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy'
 mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_rules[3][0]: 'tpu-v5litepod-256-4'
 mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
@@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy'
 mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
 mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8-single-host.txt
index a4ccb5664..6751681c4 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8-single-host.txt
@@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
 mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8.txt index 9c49dc517..803784336 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 
'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt index 2039b34d7..4c60b7979 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt index 9b245f87d..eeaca3798 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' 
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8-single-host.txt index 4f2e32abf..98ce91744 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8-single-host.txt @@ -142,9 +142,7 @@ 
mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 
'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8.txt index 76f6c56e7..655bc6f7c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' 
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt index 7f18e9684..8e3473342 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 
'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt index 98f63126c..b444f4a54 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' 
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8-single-host.txt index f7e425ce3..a8a568ada 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' 
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8.txt index a69f79636..fb83bf99a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 
1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt index 5a795e774..c40eb3471 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' 
-mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt index 7ae0b85d3..c82fcdc96 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' @@ -185,11 +181,7 @@ 
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_attention_proj_policy' mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' mesh_rules[4][1].config_modifiers[2].tpu_block_size: 1024 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8-single-host.txt index 7c37e5b52..304563976 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False 
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8.txt index c253d617f..e918ab3cd 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' 
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt index 7c37e5b52..304563976 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt index c253d617f..e918ab3cd 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' 
mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8-single-host.txt index 5768568e9..4256af27a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 
mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8.txt index 022428eaa..bc2ea3417 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' 
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt index 5768568e9..4256af27a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt index 022428eaa..bc2ea3417 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt @@ -142,9 +142,7 @@ mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -159,9 +157,7 @@ mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.experiments.text.gpt.fuji.offload_dots_saveable_policy' mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_rules[3][0]: 'tpu-v5litepod-256-4' mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py index c9dcdc989..ce247e290 100644 --- a/axlearn/experiments/text/gpt/c4_trainer.py +++ b/axlearn/experiments/text/gpt/c4_trainer.py @@ -44,7 +44,7 @@ from axlearn.common.input_lm import lm_text_preprocessor from axlearn.common.utils import get_data_dir from axlearn.experiments.text.common import DataMixtureComponent, vocab -from axlearn.experiments.text.gpt import fuji, gspmd +from axlearn.experiments.text.gpt import envy, fuji, gspmd from axlearn.experiments.text.gpt.common import mixture_train_input_source, tfds_input from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary from axlearn.experiments.trainer_config_utils import TrainerConfigFn @@ -105,4 +105,5 @@ def named_trainer_configs() -> dict[str, TrainerConfigFn]: config_map = {} config_map.update(fuji.trainer_configs(_train_input_source, 
_eval_input_sources))
     config_map.update(gspmd.trainer_configs(_train_input_source, _eval_input_sources))
+    config_map.update(envy.trainer_configs(_train_input_source, _eval_input_sources))
     return config_map
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 57d606dab..4a96ffa19 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -12,7 +12,7 @@
 import math
 from collections.abc import Sequence
-from typing import Optional, Protocol, Union
+from typing import Literal, Optional, Protocol, Union
 
 import jax.numpy as jnp
 import tensorflow as tf
@@ -57,6 +57,7 @@
 from axlearn.common.flash_attention.layer import FlashAttention
 from axlearn.common.input_dispatch import InputDispatcher
 from axlearn.common.layers import BaseNormalizationLayer, set_bias_recursively, set_norm_recursively
+from axlearn.common.mixture_of_experts import TransformerFeedForwardMoE
 from axlearn.common.optimizer_base import PartitionedGradientTransformation
 from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
 from axlearn.common.summary_writer import BaseWriter
@@ -229,6 +230,8 @@ def model_config(
     atten_logit_cap: Optional[float] = None,
     pad_token_id: Optional[int] = None,
     eos_token_id: Optional[int] = None,
+    ffn_layer_types: Optional[Sequence[Literal["dense", "sparse"]]] = None,
+    expert_cfg: TransformerFeedForwardMoE = TransformerFeedForwardMoE.default_config(),
 ) -> causal_lm.Model.Config:
     """Returns an LM model config based on the given hyperparams.
 
@@ -262,14 +265,15 @@ def model_config(
         remat_offload_dst: Destination of remat checkpointing offloading.
         pad_token_id: Int ID of the inputs to be masked for self-attention.
         eos_token_id: Int ID of the end of sequence token id.
+        ffn_layer_types: The layer types in the FFN. If None, all layers default to "dense".
+            Otherwise, each element must be one of ["dense", "sparse"].
+        expert_cfg: The expert config for the MoE FFN. This is only used if at least one layer
+            type is sparse.
 
     Returns:
         A causal LM config.
     """
-    # Feed-forward.
-    layer_cfg.feed_forward.activation = activation_fn
-    layer_cfg.feed_forward.hidden_dim = ffn_dim
-    layer_cfg.feed_forward.structure = ffn_structure
+    # First configure the base layer_cfg.
     # Attention.
     if attention_cfg is not None:
         layer_cfg.self_attention.attention = attention_cfg
@@ -283,8 +287,58 @@ def model_config(
         layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
     if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
         update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
-    # Stack.
-    transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
+
+    # Shard some FFN and attention weights over multiple axes.
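+    # Applying the sharding to the template `layer_cfg` here, before it is cloned
+    # below, lets the dense and sparse variants inherit the same weight partitioning.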
+    batch_axis_names = ("data", "expert", "fsdp")
+    set_double_shard_weights_config(
+        layer_cfg,
+        batch_axis_names=batch_axis_names,
+        fsdp_axis_names=("expert", "fsdp", "seq"),
+        tp_axis_names="model",
+        seq_axis_names="seq",
+    )
+
+    def config_dense(cfg: TransformerLayer.Config) -> TransformerLayer.Config:
+        cfg = cfg.clone()
+        cfg.feed_forward.activation = activation_fn
+        cfg.feed_forward.hidden_dim = ffn_dim
+        cfg.feed_forward.structure = ffn_structure
+        return cfg
+
+    def config_sparse(cfg: TransformerLayer.Config) -> TransformerLayer.Config:
+        cfg = cfg.clone()
+        cfg.feed_forward = expert_cfg
+        cfg.feed_forward.activation = activation_fn
+        cfg.feed_forward.hidden_dim = ffn_dim
+        cfg.feed_forward.structure = ffn_structure
+        return cfg
+
+    ffn_layer_type_to_config = {
+        "dense": config_dense,
+        "sparse": config_sparse,
+    }
+    if ffn_layer_types is None:
+        lm_layer_cfg = config_dense(layer_cfg)
+    else:
+        lm_layer_cfg = [
+            ffn_layer_type_to_config[layer_type](layer_cfg) for layer_type in ffn_layer_types
+        ]
+
+    # A single layer config is repeated num_layers times; a list of layer configs is
+    # stacked into a block that is then repeated.
+    if not isinstance(lm_layer_cfg, Sequence):
+        transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=lm_layer_cfg)
+    elif len(lm_layer_cfg) == 1:
+        # No need to stack together a single layer.
+        transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=lm_layer_cfg[0])
+    else:
+        num_layers_cfgs = len(lm_layer_cfg)
+        # Stack together the layers.
+        transformer_cfg = stack_cfg.set(
+            num_layers=num_layers // num_layers_cfgs,
+            layer=StackedTransformerLayer.default_config().set(
+                num_layers=num_layers_cfgs, layer=list(lm_layer_cfg)
+            ),
+        )
     decoder_cfg = Decoder.default_config().set(
         transformer=transformer_cfg,
         attention_mask=attention_mask,
@@ -306,7 +360,8 @@ def model_config(
             )
         }
     )
-    batch_axis_names = ("data", "expert", "fsdp")
+
+    # A few more model-level settings.
     cfg: causal_lm.Model.Config = causal_lm.Model.default_config().set(
         decoder=decoder_cfg,
         param_init=model_param_init,
@@ -315,14 +370,6 @@ def model_config(
     if z_loss_scale:
         cfg.metrics = causal_lm.metrics_config(z_loss_scale=z_loss_scale)
     cfg.dtype = jnp.float32
-    # Shard some FFN and attention weights over multiple axes.
-    set_double_shard_weights_config(
-        cfg.decoder.transformer.layer,
-        batch_axis_names=batch_axis_names,
-        fsdp_axis_names=("expert", "fsdp", "seq"),
-        tp_axis_names="model",
-        seq_axis_names="seq",
-    )
     cfg.decoder.logits_partition_spec = (batch_axis_names, "seq", "model")
     set_bias_recursively(cfg, False)
     set_norm_recursively(cfg, normalization)
@@ -697,10 +744,11 @@ def config_fn() -> InstantiableConfig:
                 }
             ),
         )
+        cfg.evalers = {}
         for name, evaler_cfg in evalers.items():
-            evaler_cfg.input.input_dispatcher.global_logical_batch_size = (
-                eval_batch_size or train_batch_size
+            evaler_cfg.input.input_dispatcher = InputDispatcher.default_config().set(
+                global_logical_batch_size=eval_batch_size or train_batch_size
             )
             evaler_cfg.set(
                 eval_policy=config_for_function(eval_every_n_steps_policy).set(
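
The dense/sparse stacking above composes hybrid stacks from a list of per-layer-type
configs. A rough sketch of the result, assuming ffn_layer_types=["dense", "sparse"]
and num_layers=12 (identifiers as in the hunk above):

    # lm_layer_cfg == [dense_cfg, sparse_cfg], built by config_dense/config_sparse.
    num_layers_cfgs = len(lm_layer_cfg)  # 2
    # The outer stack repeats the 2-layer [dense, sparse] block 12 // 2 = 6 times,
    # placing an MoE layer at every other position.
    transformer_cfg = stack_cfg.set(
        num_layers=12 // num_layers_cfgs,  # 6 repeats
        layer=StackedTransformerLayer.default_config().set(
            num_layers=num_layers_cfgs, layer=list(lm_layer_cfg)
        ),
    )
    # Note: num_layers is assumed to divide evenly by len(ffn_layer_types);
    # the integer division would otherwise silently drop trailing layers.
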
diff --git a/axlearn/experiments/text/gpt/envy.py b/axlearn/experiments/text/gpt/envy.py
new file mode 100644
index 000000000..c18c451b9
--- /dev/null
+++ b/axlearn/experiments/text/gpt/envy.py
@@ -0,0 +1,540 @@
+# Copyright © 2025 Apple Inc.
+
+"""Utilities to set up the 'Envy' MoE style model trainer configs.
+
+Adds MoE style model configs for the GPT model class:
+- Switch Transformer.
+- Apple MoE.
+
+We follow most of the practices in Switch Transformer for MoE; however, there are some key
+differences:
+- We use a GPT-style decoder-only model rather than a T5-style model.
+- We do not follow the optimizer settings in the paper; instead, we follow the practice from
+  Fuji and Gala.
+- We use the same tokenizer as the Fuji model classes.
+- We increase the hidden dimension per head to 128 for better tensor-core utilization on both
+  TPU and GPU.
+- We increase the sequence length to 8k, instead of the 512 used in most of the T5 models,
+  and increase the global tokens/batch to 8M instead of 1M.
+- We use rotary positional embeddings instead of relative positional embeddings.
+- We retain the values of num_heads, num_layers, and num_experts as specified in the paper.
+  Aside from these and the adjusted hyperparameters mentioned above, the remaining
+  hyperparameters were set arbitrarily.
+
+Architecture names follow apple varieties: Fuji, Gala, etc.
+"""
+
+import functools
+from typing import Any, Literal, Sequence, Union
+
+from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
+
+from axlearn.common import causal_lm, config
+from axlearn.common.attention import (
+    FusedGroupedQKVLinear,
+    GroupedQueryAttention,
+    RoFormerQKVLinear,
+    ScaleKey,
+    ScaleQuery,
+    TransformerLayer,
+)
+from axlearn.common.base_layer import RematSpec
+from axlearn.common.embedding import TransformerTextEmbeddings
+from axlearn.common.layers import RMSNorm
+from axlearn.common.mixture_of_experts import TransformerFeedForwardMoE, get_outer_batch_from_mesh
+from axlearn.common.trainer import SpmdTrainer
+from axlearn.common.trainer_config_modifier import (
+    ChainConfigModifier,
+    GradientAccumulationModifier,
+    MeshShapeModifier,
+    RematSpecModifier,
+)
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
+from axlearn.experiments.text.gpt.common import (
+    MESH_AXIS_NAMES,
+    SourceBuilder,
+    adamw_decoupled_learner_config,
+    evaler_config_dict,
+    flash_attention_config,
+    get_trainer_config_fn,
+    make_config_name,
+    mesh_shape_from_axes,
+)
+from axlearn.experiments.text.gpt.common import model_config as common_model_config
+from axlearn.experiments.text.gpt.common import (
+    mup_simple_adam_update_transformation,
+    scaled_hidden_dim,
+)
+from axlearn.experiments.text.gpt.fuji import offload_attention_proj_policy
+from axlearn.experiments.trainer_config_utils import TrainerConfigFn
+
+MODEL_SIZES = ("test", "Switch-Base", "Switch-Large", "Switch-XXL")
+
+NUM_EXPERTS = {
+    "test": 8,
+    "Switch-Base": 128,
+    "Switch-Large": 128,
+    "Switch-XXL": 64,
+}
+
+# T5 uses a 32128 vocab size; we make it 32768 for simplicity.
+VOCAB_SIZE = 32 * 1024
+
+MAX_SEQUENCE_LENGTH = {
+    "test": 8192,
+    "Switch-Base": 8192,
+    "Switch-Large": 8192,
+    "Switch-XXL": 8192,
+}
+
+_BASE_MODEL_HIDDEN_DIM = 768
+
+MOE_OUTER_BATCH_AXIS_NAMES = ("data", "fsdp")
+
+MOE_DIM_TO_MESH_AXIS_MAP = {
+    "me": PartitionSpec(None, None),
+    "emh": PartitionSpec("expert", "fsdp", "model"),
+    "ehm": PartitionSpec("expert", "model", "fsdp"),
+    "ogsm": PartitionSpec(MOE_OUTER_BATCH_AXIS_NAMES, "expert", None, "model"),
+    # Dispatch and combine tensors.
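+    # Dimension letters follow the GShard-style einsum notation: o=outer batch,
+    # g=groups, s=tokens per group, e=experts, c=expert capacity, m=model dim,
+    # h=per-expert hidden dim.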
+ "ogsec": PartitionSpec(MOE_OUTER_BATCH_AXIS_NAMES, None, None, "expert", None), + "oegcm": PartitionSpec(MOE_OUTER_BATCH_AXIS_NAMES, "expert", None, None, "model"), + "ogecm": PartitionSpec(MOE_OUTER_BATCH_AXIS_NAMES, None, "expert", None, "model"), + "oegch": PartitionSpec(MOE_OUTER_BATCH_AXIS_NAMES, "expert", None, None, "model"), +} + + +def common_trainer_kwargs() -> dict[str, Any]: + """Returns kwargs that are common to all configs.""" + return { + "model_kwargs": { + "z_loss_scale": 1e-6, + }, + "learner_kwargs": { + "peak_lr": 1e-2, + "alpha": 1 / 200.0, + "weight_decay": 3.16e-4, + }, + "save_every_n_steps": 5000, + "keep_every_n_steps": 5000, + "eval_every_n_steps": 25_000, + "mesh_shape": mesh_shape_from_axes(data=-1), + } + + +def get_trainer_kwargs( + model_size: str, + *, + vocab_size: int, + max_sequence_length: int, + flash_attention: bool, +) -> dict[str, Any]: + """Construct default trainer kwargs given a model size.""" + tokens_per_batch = 8 * (1024**2) # 8M tokens. + + # pylint: disable=use-dict-literal + if model_size == "test": + trainer_kwargs = dict( + model_kwargs=dict( + num_layers=4, + hidden_dim=8, + ffn_dim=scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=16), + num_heads=4, + num_kv_heads=2, + vocab_size=32, + num_experts=8, + train_capacity_factor=2.0, + num_groups=2, + ffn_layer_types=[ + "dense", + "sparse", + ], + ), + learner_kwargs=dict(), + max_sequence_length=64, + train_batch_size=16, + max_step=3000, + mesh_shape=mesh_shape_from_axes(data=-1), + ) + elif model_size == "Switch-Base": + # Num of parameters: 30B. + trainer_kwargs = dict( + model_kwargs=dict( + num_layers=12, + hidden_dim=12 * 128, + ffn_dim=scaled_hidden_dim(scale=4, round_up_to_multiples_of=128), + num_heads=12, + num_kv_heads=12, + num_experts=NUM_EXPERTS[model_size], + train_capacity_factor=2.0, + num_groups=2, + ffn_structure="hybridnorm", + # MoE layer every 2 layers. + ffn_layer_types=[ + "dense", + "sparse", + ], + ), + learner_kwargs=dict(peak_lr=0.01, weight_decay=1e-4, lr_warmup_steps=5_000), + max_sequence_length=max_sequence_length, + train_batch_size=tokens_per_batch // max_sequence_length, # 8M tokens. + max_step=250_000, + mesh_shape=mesh_shape_from_axes(fsdp=-1, expert=16), + mesh_rules=( + ( + "tpu-v5p-(1024|2048)", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, expert=16, fsdp=16) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=jax_remat_policies.dots_saveable, + ), + } + ), + ], + ), + ), + ( + "tpu-v6e-256", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, expert=16, fsdp=16) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + ], + ), + ), + ), + ) + elif model_size == "Switch-Large": + # Num of parameters: 104B. + trainer_kwargs = dict( + model_kwargs=dict( + num_layers=24, + hidden_dim=16 * 128, + ffn_dim=scaled_hidden_dim(scale=4, round_up_to_multiples_of=128), + num_heads=16, + num_kv_heads=16, + num_experts=NUM_EXPERTS[model_size], + train_capacity_factor=2.0, + num_groups=2, + ffn_structure="hybridnorm", + # MoE layer every 2 layers. 
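+                # With num_layers=24, the alternating pattern yields 12 MoE layers.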
+ ffn_layer_types=[ + "dense", + "sparse", + ], + ), + learner_kwargs=dict(peak_lr=0.01, weight_decay=1e-4, lr_warmup_steps=5_000), + max_sequence_length=max_sequence_length, + train_batch_size=tokens_per_batch // max_sequence_length, # 8M tokens. + max_step=250_000, # Most of the evals were done at 100k steps in the paper. + mesh_shape=mesh_shape_from_axes(fsdp=-1, expert=16), + mesh_rules=( + ( + "tpu-v5p-(1024|2048)", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, expert=16, fsdp=16) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + ], + ), + ), + ( + "tpu-v6e-256-4", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, expert=16, fsdp=16) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + ], + ), + ), + ( + "tpu-v6e-256", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, expert=16, fsdp=16) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=offload_attention_proj_policy, + ), + } + ), + GradientAccumulationModifier.default_config().set(grad_acc_steps=4), + ], + ), + ), + ), + ) + elif model_size == "Switch-XXL": + # Num of parameters: 520B. + trainer_kwargs = dict( + model_kwargs=dict( + num_layers=24, + hidden_dim=64 * 128, + ffn_dim=scaled_hidden_dim(scale=2.5, round_up_to_multiples_of=128), + num_heads=64, + num_kv_heads=8, + num_experts=NUM_EXPERTS[model_size], + train_capacity_factor=2.0, + num_groups=2, + ffn_structure="hybridnorm", + # MoE layer every 2 layers. + ffn_layer_types=[ + "dense", + "sparse", + ], + ), + learner_kwargs=dict(peak_lr=0.01, weight_decay=1e-4, lr_warmup_steps=5_000), + max_sequence_length=max_sequence_length, + train_batch_size=tokens_per_batch // max_sequence_length, # 8M tokens. + max_step=250_000, # Most of the evals were done at 100k steps in the paper. + # TODO(kelvin-zou): not verified with real job. + mesh_shape=mesh_shape_from_axes(fsdp=-1, expert=16, model=8), + ) + # pylint: enable=use-dict-literal + else: + raise NotImplementedError(f"Unknown model size {model_size}.") + + merged_trainer_kwargs = common_trainer_kwargs() + merged_trainer_kwargs.update( + {k: v for k, v in trainer_kwargs.items() if k not in ("model_kwargs", "learner_kwargs")} + ) + + # Update the model_kwargs + model_kwargs: dict[str, Any] = merged_trainer_kwargs.pop("model_kwargs") + model_kwargs.update(trainer_kwargs.get("model_kwargs", {})) + model_kwargs.setdefault("vocab_size", vocab_size) + + learner_kwargs: dict[str, Any] = merged_trainer_kwargs.pop("learner_kwargs") + learner_kwargs.update(trainer_kwargs.get("learner_kwargs", {})) + + mesh_shape = merged_trainer_kwargs.get("mesh_shape", mesh_shape_from_axes(data=-1)) + merged_trainer_kwargs["model_cfg"] = model_config( + flash_attention=flash_attention, mesh_shape=mesh_shape, **model_kwargs + ) + # If a model is smaller than the base model, do not scale. 
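+    # E.g., Switch-Base (hidden_dim=1536) gets a multiplier of 768 / 1536 = 0.5, while
+    # models at or below the base hidden dim are capped at 1.0 by the min().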
+    linear_layer_lr_multiplier = min(_BASE_MODEL_HIDDEN_DIM / model_kwargs["hidden_dim"], 1.0)
+    merged_trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
+        max_step=trainer_kwargs["max_step"],
+        # Enable mup-simple.
+        adam_update_transformation=mup_simple_adam_update_transformation(
+            linear_layer_lr_multiplier,
+        ),
+        **learner_kwargs,
+    )
+
+    return merged_trainer_kwargs
+
+
+def model_config(
+    *,
+    num_layers: int,
+    hidden_dim: int,
+    num_heads: int,
+    num_kv_heads: int,
+    num_experts: int,
+    vocab_size: int,
+    train_capacity_factor: float,
+    num_groups: int,
+    ffn_layer_types: Sequence[Literal["dense", "sparse"]],
+    ffn_dim: Union[int, config.FunctionConfigBase],
+    dropout_rate: float = 0.0,
+    flash_attention: bool = False,
+    mesh_shape: Union[MeshShape, HybridMeshShape],
+    **kwargs,
+) -> causal_lm.Model.Config:
+    """Returns an LM model config based on the given hyperparams.
+
+    Args:
+        num_layers: The number of Transformer Layers.
+        hidden_dim: The Transformer layer input/output dim.
+        num_heads: The number of attention heads.
+        num_kv_heads: The number of attention KV heads.
+        num_experts: The number of experts in the MoE layer.
+        vocab_size: The vocabulary size.
+        train_capacity_factor: The train capacity factor for the MoE layer.
+        num_groups: The number of groups for MoE token dispatch.
+        ffn_layer_types: The layer types in the feed-forward network.
+            Options: ["dense", "sparse"].
+        ffn_dim: The feed-forward dimension, either an int or a config function of the
+            hidden dim, e.g. following https://arxiv.org/abs/2002.05202.
+        dropout_rate: The dropout rate applied throughout the model.
+            Defaults to 0.0 (i.e. no dropout).
+        flash_attention: If True, use the flash attention implementation.
+        mesh_shape: The mesh shape, used to infer the outer batch size.
+        kwargs: Default kwargs forwarded to `common_model_config`.
+
+    Returns:
+        A causal LM config.
+    """
+    # Use RoPE for positional encodings by default.
+    # `CausalAttentionLogitBiasLayer` is already applied in the attention impl.
+    attention_mask = None
+    # RoPE embeddings: https://arxiv.org/abs/2104.09864.
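+    # RoFormerQKVLinear wraps the fused grouped QKV projection and applies the rotary
+    # embedding to queries and keys only (rotary_value=False leaves values untouched).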
+    attention_qkv_linear = RoFormerQKVLinear.default_config().set(
+        input_linear=FusedGroupedQKVLinear.default_config().set(
+            num_kv_heads=num_kv_heads,
+        ),
+        rotary_value=False,
+    )
+    attention_qkv_linear.rope_pos_emb_layer.theta = 5e5
+    norm_cfg = RMSNorm.default_config().set(eps=1e-5, forward_dtype=None)
+
+    transformer_layer_cfg = TransformerLayer.default_config()
+    if flash_attention:
+        transformer_layer_cfg.self_attention.attention = flash_attention_config()
+    else:
+        transformer_layer_cfg.self_attention.attention = GroupedQueryAttention.default_config()
+    transformer_layer_cfg.self_attention.attention.set(
+        # Use q/k-norm in keeping with:
+        #
+        query_scale=ScaleQuery.default_config().set(norm=norm_cfg.clone()),
+        key_scale=ScaleKey.default_config().set(norm=norm_cfg.clone()),
+    )
+    outer_batch_size = get_outer_batch_from_mesh(
+        mesh_axis_names=MESH_AXIS_NAMES,
+        outer_batch_axis_names=MOE_OUTER_BATCH_AXIS_NAMES,
+        mesh_shape=mesh_shape,
+    )
+    expert_config = TransformerFeedForwardMoE.default_config().set(
+        outer_batch=outer_batch_size,
+        num_experts=num_experts,
+        input_dim=hidden_dim,
+        num_groups=num_groups,
+        dim_to_mesh_axis_map=MOE_DIM_TO_MESH_AXIS_MAP,
+    )
+    expert_config.gating.train_capacity_factor = train_capacity_factor
+
+    emb_cfg: TransformerTextEmbeddings.Config = TransformerTextEmbeddings.default_config().set(
+        pos_emb=None
+    )
+    emb_cfg.token_emb.param_partition_spec = (("expert", "fsdp", "seq"), "model")
+    cfg = common_model_config(
+        num_layers=num_layers,
+        hidden_dim=hidden_dim,
+        num_heads=num_heads,
+        vocab_size=vocab_size,
+        # SwiGLU from https://arxiv.org/abs/2002.05202.
+        activation_fn=("nn.silu", "linear"),
+        ffn_dim=ffn_dim,
+        normalization=norm_cfg,
+        dropout_rate=dropout_rate,
+        emb_cfg=emb_cfg,
+        # Since we pass `layer_cfg`, this is already set.
+        attention_cfg=None,
+        attention_mask=attention_mask,
+        attention_qkv_linear=attention_qkv_linear,
+        layer_cfg=transformer_layer_cfg,
+        ffn_layer_types=ffn_layer_types,
+        expert_cfg=expert_config,
+        **kwargs,
+    )
+    if flash_attention:
+        cfg.decoder.transformer.layer.remat_spec = RematSpec(
+            prevent_cse=False, policy=jax_remat_policies.dots_saveable
+        )
+    return cfg
+
+
+def trainer_configs(
+    train_input_source: SourceBuilder,
+    eval_input_sources: SourceBuilder,
+) -> dict[str, TrainerConfigFn]:
+    """Returns a mapping from config_name to TrainerConfigFn's.
+
+    Args:
+        train_input_source: A callable (vocab_size, max_sequence_length) -> input source config.
+        eval_input_sources: A callable (vocab_size, max_sequence_length) -> eval input sources.
+    """
+    arch = "envy"
+    config_map = {}
+    vocab_size = VOCAB_SIZE
+    for model_size in MODEL_SIZES:
+        seq_len = MAX_SEQUENCE_LENGTH[model_size]
+        config_name = make_config_name(arch=arch, model_size=model_size)
+        kwargs = get_trainer_kwargs(
+            model_size,
+            vocab_size=vocab_size,
+            # Use flash attention for all model sizes except "test".
+            flash_attention=(model_size != "test"),
+            max_sequence_length=seq_len,
+        )
+
+        # The test model overrides the sequence length to a very small value.
+        seq_len = kwargs.pop("max_sequence_length", seq_len)
+
+        # pylint: disable-next=unexpected-keyword-arg,missing-kwoa
+        config_map[config_name] = get_trainer_config_fn(
+            train_input_source=train_input_source(
+                vocab_size=vocab_size, max_sequence_length=seq_len
+            ),
+            evalers=evaler_config_dict(
+                eval_input_sources(vocab_size=vocab_size, max_sequence_length=seq_len),
+            ),
+            **kwargs,
+        )
+        # Only the Switch-Base model size is runnable in single-host mode.
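+        # (make_config_name() yields names like "envy-Switch-Base"; the single-host
+        # variant below is registered as "envy-Switch-Base-single-host".)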
+        if model_size == "Switch-Base":
+
+            def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config:
+                """Makes a single-host variant of the base config."""
+
+                # pytype: disable=annotation-type-mismatch
+                cfg: SpmdTrainer.Config = config_map[base_config_name]().clone()
+                # pytype: enable=annotation-type-mismatch
+                cfg.input.batcher.feed_batch_size = 8
+                for evaler in cfg.evalers.values():
+                    evaler.input.batcher.feed_batch_size = 8
+                remat_modifier = (
+                    RematSpecModifier.default_config()
+                    .set(
+                        remat_policies={
+                            "model.decoder.transformer.layer": RematSpec(
+                                prevent_cse=True,
+                                policy=jax_remat_policies.nothing_saveable,
+                            ),
+                        }
+                    )
+                    .instantiate()
+                )
+                cfg = remat_modifier(cfg)
+                return cfg
+
+            # Make the single-host config.
+            make_single_host_config_func = functools.partial(make_single_host_config, config_name)
+            config_map[f"{config_name}-single-host"] = make_single_host_config_func
+
+    return config_map
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 57c910c70..1e60849a8 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -130,6 +130,47 @@ class Version(enum.Enum):
     },
 }
 
+
+def offload_dots_saveable_policy(*_, **__):
+    """Builds a remat policy, for use in RematSpec, that offloads dot_general_p
+    operations from device to pinned host memory.
+
+    Args:
+        *_: Ignored positional arguments.
+        **__: Ignored keyword arguments.
+
+    Returns:
+        A config that instantiates to a remat policy which offloads dot_general_p
+        results from device to pinned host memory.
+    """
+    return config_for_function(extended_checkpoint_policies.offload_dots_saveable).set(
+        offload_src="device", offload_dst="pinned_host"
+    )
+
+
+def offload_attention_proj_policy(*_, **__):
+    """Builds a remat policy, for use in RematSpec, that offloads attention
+    projection intermediates during model execution.
+
+    Args:
+        *_: Ignored positional arguments.
+        **__: Ignored keyword arguments.
+
+    Returns:
+        A config that instantiates to a checkpoint policy which offloads native attention
+        projection intermediates from device to pinned host memory, enabling
+        memory-efficient long-context training.
+    """
+    return config_for_function(
+        extended_checkpoint_policies.save_and_offload_only_these_names_regex
+    ).set(
+        names_which_can_be_saved=None,
+        names_which_can_be_offloaded=RematRegexSavePatterns.NATIVE_ATTENTION.value,
+        offload_src="device",
+        offload_dst="pinned_host",
+    )
+
+
 # Llama3 uses 16m tokens after 2.87T tokens.
 # https://arxiv.org/pdf/2407.21783
 TOKENS_PER_BATCH = {
@@ -258,18 +299,7 @@ def get_trainer_kwargs(
     rope_theta = ROPE_THETA[version]
     trn2_config = _generate_trn2_custom_configs(model_size, version=version)
 
-    offload_dots_saveable_policy = config_for_function(
-        extended_checkpoint_policies.offload_dots_saveable
-    ).set(offload_src="device", offload_dst="pinned_host")
-    # To make it work better with v3 8k sequence length.
-    offload_attention_proj_policy = config_for_function(
-        extended_checkpoint_policies.save_and_offload_only_these_names_regex
-    ).set(
-        names_which_can_be_saved=None,
-        names_which_can_be_offloaded=RematRegexSavePatterns.NATIVE_ATTENTION.value,
-        offload_src="device",
-        offload_dst="pinned_host",
-    )
+
     # dict() is more readable here.
     # pylint: disable=use-dict-literal
     if model_size == "test":