Tracing parameter initialization functions that are sometimes deterministic but need to curry rng key device #31106
Unanswered
james-sony asked this question in Q&A
Replies: 1 comment
-
Looking further, I see JAX has a device context manager, jax.default_device. Maybe the right solution, then, is to not set the device when creating the array and just use the context manager instead? For example:

    def my_func(key):
        return jnp.zeros([10])  # device setting removed

and then calling code that wants precise control over the allocation just does:

    with jax.default_device(key.device):  # let key be defined somewhere above
        y = my_func(key)

That way, whatever combination of RNG and deterministic allocations my_func does, they will all end up on the right device. At the same time, since the function never needs to reference the device, it's safe to run it through traced functions like

Does that sound right?
-
I'm building neural net parameter initializer functions. Some are random, while others are deterministic (like zeros for bias terms). I want callers to be able to control device placement through these functions. I believe the usual way JAX handles this is that an RNG operation runs on the device of its RNG key. Neat, that sounds fine. But for deterministic operations I don't get that for free. However, I can do something like
jnp.zeros(shape, device=rng_key.device)
That works too. Cool.
But now comes the problem. A handy utility is to run jax.eval_shape on the initialization process itself, so that I can get the shapes without actually allocating the parameters. This only works if I don't carry over the RNG key's device, because inside the trace the key is a tracer object, and tracers don't have the device attribute.
Here's a simple example demonstrating that:
The error is:
I can think of hacks around this specific case, but I suspect other edge cases like it will keep coming up. Is there a good, principled way to handle this?
For example, this hack works, but it feels... ugly, and like I'm just waiting for it to bite me.