@@ -22,24 +22,26 @@ def setup(
22
22
"""Sets up the JAX environment for SPMD.
23
23
24
24
Args:
25
- jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu".
26
- distributed_coordinator: The distributed coordinator address (in the form of <host>:<port>).
27
- Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the
28
- coordinator will be configured automatically when `num_processes` and `process_id` are
29
- provided.
30
- num_processes: The number of processes. Needed only if distributed initialization is desired
31
- for `jax_backend != "tpu"`.
32
- process_id: The process ID (the process rank). Needed only if distributed initialization is
33
- desired for `jax_backend != "tpu"`.
34
- initialization_timeout: The jax distributed initialization timeout in seconds. If None, uses
35
- jax default.
25
+ jax_backend: The distributed backend. Can be "cpu", "gpu", "tpu", or "proxy".
26
+ distributed_coordinator: The distributed coordinator address (e.g., "<host>:<port>").
27
+ If jax_backend is "tpu", this may be automatically inferred by JAX.
28
+ If jax_backend is "proxy", this is ignored.
29
+ num_processes: The number of processes.
30
+ If jax_backend is "tpu", this may be automatically inferred by JAX.
31
+ If jax_backend is "proxy", this is ignored.
32
+ process_id: The process ID (the process rank).
33
+ If jax_backend is "tpu", this may be automatically inferred by JAX.
34
+ If jax_backend is "proxy", this is ignored.
35
+ initialization_timeout: The jax distributed initialization timeout in seconds.
36
+ If None, uses jax default.
37
+ If jax_backend is "proxy", this is ignored.
36
38
37
39
Raises:
38
40
ValueError: If any of the following conditions are met:
39
- * distributed_coordinator, num_processes, or process_id are not None when
40
- jax_backend is "tpu";
41
- * one of num_processes or process_id is None when jax_backend is not "tpu";
42
- * distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1 .
41
+ * `jax_backend` not in ("tpu", "proxy") and (`num_processes` is None or `process_id` is
42
+ None).
43
+ * `jax_backend` not in ("tpu", "proxy"), `num_processes` > 1, and
44
+ `distributed_coordinator` is None .
43
45
"""
44
46
# Use a GSPMD-friendly PRNG implementation.
45
47
jax .config .update ("jax_default_prng_impl" , "rbg" )
0 commit comments