diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index a0b4b08cf..a1e12c209 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -114,25 +114,13 @@ def setup(): logging.info("LIBTPU_INIT_ARGS='%s'", os.environ["LIBTPU_INIT_ARGS"]) with _init_context(): - if FLAGS.jax_backend == "proxy": - # AXLearn assumes rbg PRNG implementation and restore from checkpoint - # will fail on pathways if this isn't set. This is due shape of [4] - # being hardcoded here: - # https://github.com/apple/axlearn/blob/8bb4421e62c815ef9f1ba3679c3277b8bbc6a449/axlearn/common/trainer.py#L330 - jax.config.update("jax_default_prng_impl", "rbg") - - # pylint: disable-next=import-error,import-outside-toplevel - import pathwaysutils # pytype: disable=import-error - - pathwaysutils.initialize() - else: - setup_spmd( - distributed_coordinator=FLAGS.distributed_coordinator, - num_processes=FLAGS.num_processes, - process_id=FLAGS.process_id, - jax_backend=FLAGS.jax_backend, - initialization_timeout=FLAGS.initialization_timeout, - ) + setup_spmd( + distributed_coordinator=FLAGS.distributed_coordinator, + num_processes=FLAGS.num_processes, + process_id=FLAGS.process_id, + jax_backend=FLAGS.jax_backend, + initialization_timeout=FLAGS.initialization_timeout, + ) if FLAGS.jax_profiler_port is not None: # Start jax.profiler for Tensorboard and profiling in open source. diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 1917e99a3..708b20439 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -22,28 +22,37 @@ def setup( """Sets up the JAX environment for SPMD. Args: - jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu". - distributed_coordinator: The distributed coordinator address (in the form of <ip>:<port>). - Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the - coordinator will be configured automatically when `num_processes` and `process_id` are - provided. 
- num_processes: The number of processes. Needed only if distributed initialization is desired - for `jax_backend != "tpu"`. - process_id: The process ID (the process rank). Needed only if distributed initialization is - desired for `jax_backend != "tpu"`. - initialization_timeout: The jax distributed initialization timeout in seconds. If None, uses - jax default. + jax_backend: The distributed backend. Can be "cpu", "gpu", "tpu", or "proxy". + distributed_coordinator: The distributed coordinator address (e.g., "<ip>:<port>"). + If jax_backend is "tpu", this may be automatically inferred by JAX. + If jax_backend is "proxy", this is ignored. + num_processes: The number of processes. + If jax_backend is "tpu", this may be automatically inferred by JAX. + If jax_backend is "proxy", this is ignored. + process_id: The process ID (the process rank). + If jax_backend is "tpu", this may be automatically inferred by JAX. + If jax_backend is "proxy", this is ignored. + initialization_timeout: The jax distributed initialization timeout in seconds. + If None, uses jax default. + If jax_backend is "proxy", this is ignored. Raises: ValueError: If any of the following conditions are met: - * distributed_coordinator, num_processes, or process_id are not None when - jax_backend is "tpu"; - * one of num_processes or process_id is None when jax_backend is not "tpu"; - * distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1. + * `jax_backend` not in ("tpu", "proxy") and (`num_processes` is None or `process_id` is + None). + * `jax_backend` not in ("tpu", "proxy"), `num_processes` > 1, and + `distributed_coordinator` is None. """ # Use a GSPMD-friendly PRNG implementation. 
jax.config.update("jax_default_prng_impl", "rbg") + if jax_backend == "proxy": + # pylint: disable-next=import-error,import-outside-toplevel + import pathwaysutils # pytype: disable=import-error + + pathwaysutils.initialize() + return + global _jax_distributed_initialized # pylint: disable=global-statement if not _jax_distributed_initialized: init_kwargs = {}