Skip to content

Commit c7783cc

Browse files
committed
update docstring for setup
1 parent f7d548d commit c7783cc

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

axlearn/common/utils_spmd.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,26 @@ def setup(
2222
"""Sets up the JAX environment for SPMD.
2323
2424
Args:
25-
jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu".
26-
distributed_coordinator: The distributed coordinator address (in the form of <host>:<port>).
27-
Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the
28-
coordinator will be configured automatically when `num_processes` and `process_id` are
29-
provided.
30-
num_processes: The number of processes. Needed only if distributed initialization is desired
31-
for `jax_backend != "tpu"`.
32-
process_id: The process ID (the process rank). Needed only if distributed initialization is
33-
desired for `jax_backend != "tpu"`.
34-
initialization_timeout: The jax distributed initialization timeout in seconds. If None, uses
35-
jax default.
25+
jax_backend: The distributed backend. Can be "cpu", "gpu", "tpu", or "proxy".
26+
distributed_coordinator: The distributed coordinator address (e.g., "<host>:<port>").
27+
If jax_backend is "tpu", this may be automatically inferred by JAX.
28+
If jax_backend is "proxy", this is ignored.
29+
num_processes: The number of processes.
30+
If jax_backend is "tpu", this may be automatically inferred by JAX.
31+
If jax_backend is "proxy", this is ignored.
32+
process_id: The process ID (the process rank).
33+
If jax_backend is "tpu", this may be automatically inferred by JAX.
34+
If jax_backend is "proxy", this is ignored.
35+
initialization_timeout: The jax distributed initialization timeout in seconds.
36+
If None, uses jax default.
37+
If jax_backend is "proxy", this is ignored.
3638
3739
Raises:
3840
ValueError: If any of the following conditions are met:
39-
* distributed_coordinator, num_processes, or process_id are not None when
40-
jax_backend is "tpu";
41-
* one of num_processes or process_id is None when jax_backend is not "tpu";
42-
* distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1.
41+
* `jax_backend` not in ("tpu", "proxy") and (`num_processes` is None or `process_id` is
42+
None).
43+
* `jax_backend` not in ("tpu", "proxy"), `num_processes` > 1, and
44+
`distributed_coordinator` is None.
4345
"""
4446
# Use a GSPMD-friendly PRNG implementation.
4547
jax.config.update("jax_default_prng_impl", "rbg")

0 commit comments

Comments
 (0)