Use of jax.device_put vs jax.lax.with_sharding_constraint
#31105
-
`with_sharding_constraint` can be used in eager mode too, but there it is just an identity jit, i.e. an identity function jitted with `out_shardings` set to the requested sharding. Under a jit, `with_sharding_constraint` is a strict constraint which the compiler has to respect.
Usually the recommendation is to use `device_put` outside …
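A minimal sketch of the cases described in the reply, assuming a recent JAX (around 0.7) and a 1-D mesh built over whatever devices are available (a single CPU device works too). The mesh, the axis name `"x"`, and the `scale` function are illustrative choices, not part of the original discussion. It checks that eager `device_put`, eager `with_sharding_constraint`, and an identity function jitted with `out_shardings` all leave the array with the requested sharding.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all available devices (one CPU device is fine).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

# Leading dimension is a multiple of the device count so it splits evenly.
x = jnp.arange(8 * len(jax.devices()), dtype=jnp.float32)

# 1. Eager: device_put places/reshards the array immediately.
x_put = jax.device_put(x, sharding)

# 2. Eager: with_sharding_constraint also works outside jit; per the reply,
#    it behaves like an identity function jitted with a fixed output sharding.
x_wsc = jax.lax.with_sharding_constraint(x, sharding)
x_id = jax.jit(lambda a: a, out_shardings=sharding)(x)

# 3. Under jit: with_sharding_constraint pins the sharding of an intermediate.
@jax.jit
def scale(a):
    b = a * 2.0
    b = jax.lax.with_sharding_constraint(b, sharding)  # constrain intermediate
    return b + 1.0

y = scale(x_put)
```

Inside `scale`, the constraint applies to the intermediate `b`; the sharding of the final output is left to the compiler, so only its values are worth checking here.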
-
Thanks for the reply, but some things are still not clear to me. When using … My understanding is that …
-
As the question says, it is unclear from the documentation what the different uses of these two are. First some context: the question stems from @mjo22's effort to rewrite the auto-parallelisation tutorial for equinox and my own attempt at implementing the code in the tutorial, which led me to notice a duplicate use of equinox's `filter_shard` function.

Essentially, the issue is that as of 0.7.0, JAX recommends using `device_put` to shard arrays in eager mode and `lax.with_sharding_constraint` inside jitted functions. `equinox.filter_shard` uses the latter, and while testing I found that it works just fine outside jitted functions. The first question is: is there any downside to using a `lax` function such as this one in eager mode? My understanding is that there isn't in this case, and that the only difference between the two is that `device_put` forces a particular sharding layout on the arrays which the compiler must stick to, while `with_sharding_constraint` only provides a suggestion which the compiler needs to consider but is otherwise free to override when optimizing the sharding layout.

If this is true, then how could we achieve the same result in eager mode, that is, before passing the arrays to jitted functions? Is the only way to use `in_shardings` with `Layout.AUTO`? This is important as it will influence how the equinox filter functions are rewritten, especially `eqx.filter_jit`.
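For the eager-mode part of the question, one possible approach is sketched below: shard the array leaves of a pytree with `device_put` before calling the jitted function, and state the expected layout with `in_shardings`/`out_shardings` on `jax.jit`. `shard_array_leaves` is a hypothetical helper written only for this example; it is not how `eqx.filter_shard` is implemented. The mesh, the `params` dict, and the `affine` function are likewise illustrative assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

def shard_array_leaves(tree, sharding):
    """Hypothetical eager-mode helper: device_put every jax.Array leaf of a
    pytree onto `sharding`, passing all other leaves through unchanged."""
    def put(leaf):
        if isinstance(leaf, jax.Array):
            return jax.device_put(leaf, sharding)
        return leaf  # e.g. strings or other static, non-array leaves
    return jax.tree_util.tree_map(put, tree)

params = {
    "w": jnp.ones((4 * n,), dtype=jnp.float32),
    "b": jnp.zeros((4 * n,), dtype=jnp.float32),
    "name": "linear",  # non-array leaf, left untouched by the helper
}
params = shard_array_leaves(params, sharding)

def affine(w, b, x):
    return w * x + b

# The inputs are already committed to `sharding` before the call;
# in_shardings/out_shardings make the expected layout explicit.
affine_jit = jax.jit(
    affine,
    in_shardings=(sharding, sharding, sharding),
    out_shardings=sharding,
)

x = jax.device_put(jnp.arange(4 * n, dtype=jnp.float32), sharding)
out = affine_jit(params["w"], params["b"], x)
```

Because the inputs already match `in_shardings`, jit does not need to move them; if an argument were committed to a different sharding, jit would be expected to complain rather than silently reshard it, which is one reason to settle placement eagerly with `device_put`.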