
What does in_specs=None mean in shard_map, and what is the difference with in_specs=P()? #31107

Answered by mattjj
HeavyCrab asked this question in Q&A

Thanks for the question!

I agree we shouldn't use the term "static" here.

An in_spec of None means to treat that input the same as if it had been closed over, i.e.

@jax.shard_map(in_specs=(None, <A>), <B>)
def f(x, <C>):
  <D>

f(x, <E>)

means the same thing as

@jax.shard_map(in_specs=<A>, <B>)
def f_(<C>):
  <D>

f_(<E>)

Everything else follows from that, but just to be concrete:

> But if it’s a jax.Array, what’s the difference between this and P()?

There's not much of a difference, because closed-over arraylikes are treated as if they have specs of P().

> If I mark it as None, does that imply some GPU-to-GPU communication occurs at every operation, making it less efficient than P()?

Und…

Replies: 1 comment 1 reply

1 reply: @HeavyCrab

Answer selected by HeavyCrab
Category: Q&A · Labels: None yet · 2 participants