using a single random key as (unmapped) input to a vmap'ped function. #30711

Answered by jakevdp
hrbigelow asked this question in General

If you pass a single unmapped random key, the same key is used for every element of the batch: this is working as designed. If you'd like a different key for each batch element, you can split the key (or fold an index into it) and map over the result, e.g. like this:

>>> import jax
>>> import jax.numpy as jnp
>>> jax.vmap(jax.random.categorical)(jax.random.split(keys, 10), jnp.broadcast_to(logits, (10, *logits.shape)))
Array([2, 1, 1, 3, 2, 0, 3, 3, 3, 1], dtype=int32)
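
A minimal sketch, assuming a recent JAX version, that contrasts the unmapped-key behaviour with two ways to get a distinct key per batch element. The question's own `keys` and `logits` aren't shown here, so `key = jax.random.PRNGKey(0)` and the 4-category `logits` vector below are hypothetical stand-ins. It passes `in_axes=(0, None)` so `logits` stays unmapped instead of being broadcast, and uses `jax.random.fold_in` as the "fold" alternative to splitting.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's inputs (not from the original thread).
key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
n = 10

# 1) Unmapped key (in_axes=None for the key): every batch element receives the
#    same key, and here the logits are identical too, so every draw is the same.
same = jax.vmap(jax.random.categorical, in_axes=(None, 0))(
    key, jnp.broadcast_to(logits, (n, *logits.shape)))

# 2) Split the key into n keys and map over them; logits stay unmapped via
#    in_axes=None, so no broadcast copy of logits is needed.
split_draws = jax.vmap(jax.random.categorical, in_axes=(0, None))(
    jax.random.split(key, n), logits)

# 3) Fold a per-element index into the single key inside the mapped function,
#    giving each batch element its own derived key.
def sample_with_index(i, logits):
    return jax.random.categorical(jax.random.fold_in(key, i), logits)

folded_draws = jax.vmap(sample_with_index, in_axes=(0, None))(jnp.arange(n), logits)

print(same)          # all entries equal
print(split_draws)   # one draw per split key
print(folded_draws)  # one draw per folded-in index
```

Splitting and folding in an index are both common ways to derive independent keys; `fold_in` can be convenient when each batch element already has a natural integer index.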

Answer selected by hrbigelow