In the documentation of …

It is unclear what "static" means here. According to my understanding, … From my experiments, this does indeed seem to be the case. See the code below:

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import PartitionSpec as P
import timeit

# Needs at least 2 devices (on CPU, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=2).
mesh = jax.make_mesh((2,), ('d',))

# f: x has an in_spec of None, y is replicated with P().
@partial(jax.shard_map, mesh=mesh, in_specs=(None, P()), out_specs=P())
def f(x, y):
    z = 0
    for i in range(10):
        z = z + x + y
    return z

# g: both x and y are replicated with P().
@partial(jax.shard_map, mesh=mesh, in_specs=(P(), P()), out_specs=P())
def g(x, y):
    z = 0
    for i in range(10):
        z = z + x + y
    return z

x = jnp.array(1)
y = jnp.array(2)

# Warm-up calls so that compilation is not included in the timings.
z1 = f(x, y)
z2 = g(x, y)

print("start")
print("f")
print(timeit.timeit(lambda: jax.block_until_ready(f(x, y)), number=50))
print("g")
print(timeit.timeit(lambda: jax.block_until_ready(g(x, y)), number=50))
```

The output is:
Thanks for the question! I agree we shouldn't use the term "static" here.

An in_spec of `None` means to treat that input the same as if it had been closed over, i.e.

```python
@jax.shard_map(in_specs=(None, <A>), <B>)
def f(x, <C>):
    <D>

f(x, <E>)
```

means the same thing as

```python
@jax.shard_map(in_specs=<A>, <B>)
def f_(<C>):
    <D>

f_(<E>)
```

Everything else follows from that, but just to be concrete:

There's not much of a difference, because closed-over arraylikes are treated as if they have specs of `P()`.

Under a …

Outside of a …

What do you think?