From 19617d502720211131783e27a7fedc05712b7e0a Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Mon, 9 Jun 2025 12:21:51 -0700
Subject: [PATCH 1/4] Pathways: reuse setup_spmd for pathways init

---
 axlearn/common/launch.py     |  26 +++------
 axlearn/common/utils_spmd.py | 101 +++++++++++++++++++----------------
 2 files changed, 62 insertions(+), 65 deletions(-)

diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py
index a0b4b08cf..a1e12c209 100644
--- a/axlearn/common/launch.py
+++ b/axlearn/common/launch.py
@@ -114,25 +114,13 @@ def setup():
         logging.info("LIBTPU_INIT_ARGS='%s'", os.environ["LIBTPU_INIT_ARGS"])
 
     with _init_context():
-        if FLAGS.jax_backend == "proxy":
-            # AXLearn assumes rbg PRNG implementation and restore from checkpoint
-            # will fail on pathways if this isn't set. This is due shape of [4]
-            # being hardcoded here:
-            # https://github.com/apple/axlearn/blob/8bb4421e62c815ef9f1ba3679c3277b8bbc6a449/axlearn/common/trainer.py#L330
-            jax.config.update("jax_default_prng_impl", "rbg")
-
-            # pylint: disable-next=import-error,import-outside-toplevel
-            import pathwaysutils  # pytype: disable=import-error
-
-            pathwaysutils.initialize()
-        else:
-            setup_spmd(
-                distributed_coordinator=FLAGS.distributed_coordinator,
-                num_processes=FLAGS.num_processes,
-                process_id=FLAGS.process_id,
-                jax_backend=FLAGS.jax_backend,
-                initialization_timeout=FLAGS.initialization_timeout,
-            )
+        setup_spmd(
+            distributed_coordinator=FLAGS.distributed_coordinator,
+            num_processes=FLAGS.num_processes,
+            process_id=FLAGS.process_id,
+            jax_backend=FLAGS.jax_backend,
+            initialization_timeout=FLAGS.initialization_timeout,
+        )
 
     if FLAGS.jax_profiler_port is not None:
         # Start jax.profiler for Tensorboard and profiling in open source.

diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py
index 1917e99a3..add967d0b 100644
--- a/axlearn/common/utils_spmd.py
+++ b/axlearn/common/utils_spmd.py
@@ -44,55 +44,64 @@ def setup(
     # Use a GSPMD-friendly PRNG implementation.
     jax.config.update("jax_default_prng_impl", "rbg")
 
-    global _jax_distributed_initialized  # pylint: disable=global-statement
-    if not _jax_distributed_initialized:
-        init_kwargs = {}
-        if initialization_timeout is not None:
-            init_kwargs["initialization_timeout"] = initialization_timeout
+    if jax_backend == "proxy":
+        # pylint: disable-next=import-error,import-outside-toplevel
+        import pathwaysutils  # pytype: disable=import-error
 
-        if jax_backend == "tpu":
-            if (distributed_coordinator is None) ^ (process_id is None):
-                raise ValueError(
-                    "distributed_coordinator and process_id should be both None or both "
-                    f"not-None, but got {distributed_coordinator=}, {process_id=}"
-                )
-            init_kwargs.update(
-                coordinator_address=distributed_coordinator,
-                process_id=process_id,
-                # This is optional.
-                num_processes=num_processes,
-            )
-        else:
-            if distributed_coordinator is None and num_processes is None and process_id is None:
-                logging.info(
-                    "Skipping distributed initialization for %s backend, "
-                    "since distributed_coordinator, num_processes, and process_id are all None.",
-                    jax_backend,
-                )
-                return
+        pathwaysutils.initialize()
+    else:
+        global _jax_distributed_initialized  # pylint: disable=global-statement
+        if not _jax_distributed_initialized:
+            init_kwargs = {}
+            if initialization_timeout is not None:
+                init_kwargs["initialization_timeout"] = initialization_timeout
 
-            if num_processes is None or process_id is None:
-                raise ValueError(
-                    "num_processes and process_id should be provided together "
-                    f"if distributed initialization is desired for backend {jax_backend}. "
-                    f"Instead, got num_processes={num_processes}, process_id={process_id}."
-                )
+            if jax_backend == "tpu":
+                if (distributed_coordinator is None) ^ (process_id is None):
+                    raise ValueError(
+                        "distributed_coordinator and process_id should be both None or both "
+                        f"not-None, but got {distributed_coordinator=}, {process_id=}"
+                    )
+                init_kwargs.update(
+                    coordinator_address=distributed_coordinator,
+                    process_id=process_id,
+                    # This is optional.
+                    num_processes=num_processes,
+                )
+            else:
+                if distributed_coordinator is None and num_processes is None and process_id is None:
+                    logging.info(
+                        "Skipping distributed initialization for %s backend, "
+                        "since distributed_coordinator, num_processes, and process_id "
+                        "are all None.",
+                        jax_backend,
+                    )
+                    return
 
-            if not distributed_coordinator:
-                if num_processes == 1:
-                    distributed_coordinator = f"localhost:{portpicker.pick_unused_port()}"
-                else:
-                    raise ValueError(f"Unknown distributed_coordinator: {distributed_coordinator}")
+                if num_processes is None or process_id is None:
+                    raise ValueError(
+                        "num_processes and process_id should be provided together "
+                        f"if distributed initialization is desired for backend {jax_backend}. "
+                        f"Instead, got num_processes={num_processes}, process_id={process_id}."
+                    )
 
-            init_kwargs.update(
-                coordinator_address=distributed_coordinator,
-                num_processes=num_processes,
-                process_id=process_id,
-            )
-            if jax_backend == "gpu":
-                # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying
-                # local_device_ids arg allows us to maintain expected behavior
-                init_kwargs["local_device_ids"] = list(range(8))
+                if not distributed_coordinator:
+                    if num_processes == 1:
+                        distributed_coordinator = f"localhost:{portpicker.pick_unused_port()}"
+                    else:
+                        raise ValueError(
+                            f"Unknown distributed_coordinator: {distributed_coordinator}"
+                        )
 
-        jax.distributed.initialize(**init_kwargs)
-        _jax_distributed_initialized = True
+                init_kwargs.update(
+                    coordinator_address=distributed_coordinator,
+                    num_processes=num_processes,
+                    process_id=process_id,
+                )
+                if jax_backend == "gpu":
+                    # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying
+                    # local_device_ids arg allows us to maintain expected behavior
+                    init_kwargs["local_device_ids"] = list(range(8))
+
+            jax.distributed.initialize(**init_kwargs)
+            _jax_distributed_initialized = True

From dd4425fe101fc2f79928620bc64f05a399c2522c Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Tue, 10 Jun 2025 10:57:37 -0700
Subject: [PATCH 2/4] early return

---
 axlearn/common/utils_spmd.py | 105 +++++++++++++++++------------------
 1 file changed, 52 insertions(+), 53 deletions(-)

diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py
index add967d0b..b7e740357 100644
--- a/axlearn/common/utils_spmd.py
+++ b/axlearn/common/utils_spmd.py
@@ -49,59 +49,58 @@ def setup(
         import pathwaysutils  # pytype: disable=import-error
 
         pathwaysutils.initialize()
-    else:
-        global _jax_distributed_initialized  # pylint: disable=global-statement
-        if not _jax_distributed_initialized:
-            init_kwargs = {}
-            if initialization_timeout is not None:
-                init_kwargs["initialization_timeout"] = initialization_timeout
-
-            if jax_backend == "tpu":
-                if (distributed_coordinator is None) ^ (process_id is None):
-                    raise ValueError(
-                        "distributed_coordinator and process_id should be both None or both "
-                        f"not-None, but got {distributed_coordinator=}, {process_id=}"
-                    )
-                init_kwargs.update(
-                    coordinator_address=distributed_coordinator,
-                    process_id=process_id,
-                    # This is optional.
-                    num_processes=num_processes,
-                )
-            else:
-                if distributed_coordinator is None and num_processes is None and process_id is None:
-                    logging.info(
-                        "Skipping distributed initialization for %s backend, "
-                        "since distributed_coordinator, num_processes, and process_id "
-                        "are all None.",
-                        jax_backend,
-                    )
-                    return
+        return
+
+    global _jax_distributed_initialized  # pylint: disable=global-statement
+    if not _jax_distributed_initialized:
+        init_kwargs = {}
+        if initialization_timeout is not None:
+            init_kwargs["initialization_timeout"] = initialization_timeout
+
+        if jax_backend == "tpu":
+            if (distributed_coordinator is None) ^ (process_id is None):
+                raise ValueError(
+                    "distributed_coordinator and process_id should be both None or both "
+                    f"not-None, but got {distributed_coordinator=}, {process_id=}"
+                )
+            init_kwargs.update(
+                coordinator_address=distributed_coordinator,
+                process_id=process_id,
+                # This is optional.
+                num_processes=num_processes,
+            )
+        else:
+            if distributed_coordinator is None and num_processes is None and process_id is None:
+                logging.info(
+                    "Skipping distributed initialization for %s backend, "
+                    "since distributed_coordinator, num_processes, and process_id "
+                    "are all None.",
+                    jax_backend,
+                )
+                return
 
-                if num_processes is None or process_id is None:
-                    raise ValueError(
-                        "num_processes and process_id should be provided together "
-                        f"if distributed initialization is desired for backend {jax_backend}. "
-                        f"Instead, got num_processes={num_processes}, process_id={process_id}."
-                    )
+            if num_processes is None or process_id is None:
+                raise ValueError(
+                    "num_processes and process_id should be provided together "
+                    f"if distributed initialization is desired for backend {jax_backend}. "
+                    f"Instead, got num_processes={num_processes}, process_id={process_id}."
+                )
 
-                if not distributed_coordinator:
-                    if num_processes == 1:
-                        distributed_coordinator = f"localhost:{portpicker.pick_unused_port()}"
-                    else:
-                        raise ValueError(
-                            f"Unknown distributed_coordinator: {distributed_coordinator}"
-                        )
+            if not distributed_coordinator:
+                if num_processes == 1:
+                    distributed_coordinator = f"localhost:{portpicker.pick_unused_port()}"
+                else:
+                    raise ValueError(f"Unknown distributed_coordinator: {distributed_coordinator}")
 
-                init_kwargs.update(
-                    coordinator_address=distributed_coordinator,
-                    num_processes=num_processes,
-                    process_id=process_id,
-                )
-                if jax_backend == "gpu":
-                    # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying
-                    # local_device_ids arg allows us to maintain expected behavior
-                    init_kwargs["local_device_ids"] = list(range(8))
-
-            jax.distributed.initialize(**init_kwargs)
-            _jax_distributed_initialized = True
+        init_kwargs.update(
+            coordinator_address=distributed_coordinator,
+            num_processes=num_processes,
+            process_id=process_id,
+        )
+        if jax_backend == "gpu":
+            # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying
+            # local_device_ids arg allows us to maintain expected behavior
+            init_kwargs["local_device_ids"] = list(range(8))
+
+        jax.distributed.initialize(**init_kwargs)
+        _jax_distributed_initialized = True

From 935db4616c4ff10a22e7867c9dfc625c5a9dfda2 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Tue, 10 Jun 2025 10:58:46 -0700
Subject: [PATCH 3/4] formatting

---
 axlearn/common/utils_spmd.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py
index b7e740357..5f192cf58 100644
--- a/axlearn/common/utils_spmd.py
+++ b/axlearn/common/utils_spmd.py
@@ -73,8 +73,7 @@ def setup(
             if distributed_coordinator is None and num_processes is None and process_id is None:
                 logging.info(
                     "Skipping distributed initialization for %s backend, "
-                    "since distributed_coordinator, num_processes, and process_id "
-                    "are all None.",
+                    "since distributed_coordinator, num_processes, and process_id are all None.",
                     jax_backend,
                 )
                 return

From c7783cc424e3c4d7cd1a7af71b379fb0500cb7a9 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Mon, 23 Jun 2025 08:31:55 -0700
Subject: [PATCH 4/4] update docstring for setup

---
 axlearn/common/utils_spmd.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py
index 5f192cf58..708b20439 100644
--- a/axlearn/common/utils_spmd.py
+++ b/axlearn/common/utils_spmd.py
@@ -22,24 +22,26 @@ def setup(
     """Sets up the JAX environment for SPMD.
 
     Args:
-        jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu".
-        distributed_coordinator: The distributed coordinator address (in the form of <ip>:<port>).
-            Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the
-            coordinator will be configured automatically when `num_processes` and `process_id` are
-            provided.
-        num_processes: The number of processes. Needed only if distributed initialization is desired
-            for `jax_backend != "tpu"`.
-        process_id: The process ID (the process rank). Needed only if distributed initialization is
-            desired for `jax_backend != "tpu"`.
-        initialization_timeout: The jax distributed initialization timeout in seconds. If None, uses
-            jax default.
+        jax_backend: The distributed backend. Can be "cpu", "gpu", "tpu", or "proxy".
+        distributed_coordinator: The distributed coordinator address (e.g., "<ip>:<port>").
+            If jax_backend is "tpu", this may be automatically inferred by JAX.
+            If jax_backend is "proxy", this is ignored.
+        num_processes: The number of processes.
+            If jax_backend is "tpu", this may be automatically inferred by JAX.
+            If jax_backend is "proxy", this is ignored.
+        process_id: The process ID (the process rank).
+            If jax_backend is "tpu", this may be automatically inferred by JAX.
+            If jax_backend is "proxy", this is ignored.
+        initialization_timeout: The jax distributed initialization timeout in seconds.
+            If None, uses jax default.
+            If jax_backend is "proxy", this is ignored.
 
     Raises:
         ValueError: If any of the following conditions are met:
-            * distributed_coordinator, num_processes, or process_id are not None when
-              jax_backend is "tpu";
-            * one of num_processes or process_id is None when jax_backend is not "tpu";
-            * distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1.
+            * `jax_backend` not in ("tpu", "proxy") and (`num_processes` is None or `process_id` is
+              None).
+            * `jax_backend` not in ("tpu", "proxy"), `num_processes` > 1, and
+              `distributed_coordinator` is None.
     """
     # Use a GSPMD-friendly PRNG implementation.
     jax.config.update("jax_default_prng_impl", "rbg")
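
The series consolidates Pathways initialization behind `utils_spmd.setup()`: the "proxy" branch imports `pathwaysutils`, calls `pathwaysutils.initialize()`, and returns early, while the other backends keep the existing `jax.distributed.initialize()` path. The minimal usage sketch below (not part of the patches) illustrates how the consolidated entry point is expected to be called after this series; the concrete argument values (coordinator address, process counts) are illustrative examples only.

# Usage sketch, assuming the patched axlearn/common/utils_spmd.py; values are examples.
from axlearn.common import utils_spmd

# Pathways/proxy backend: setup() imports pathwaysutils, calls
# pathwaysutils.initialize(), and returns early; the coordinator/process
# arguments are ignored on this path (requires the pathwaysutils package).
utils_spmd.setup(jax_backend="proxy")

# TPU backend: distributed_coordinator, num_processes, and process_id may all
# be left as None, in which case jax.distributed.initialize() infers them.
utils_spmd.setup(jax_backend="tpu")

# GPU (or CPU) backend with explicit multi-process settings: num_processes and
# process_id must be provided together, and a coordinator address is required
# when num_processes > 1; otherwise setup() raises ValueError.
utils_spmd.setup(
    jax_backend="gpu",
    distributed_coordinator="10.0.0.1:8476",  # example address, not from the patches
    num_processes=2,
    process_id=0,
)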