JAX NumPy to NumPy is CRAZILY SLOW #30891
-
I wrote an optimization function using JAX. Completing over 100 optimizations takes just a few seconds, but converting the results into Python floats takes over 1 minute! I tried several methods, and they are all slow.
Did I do anything wrong, or is there a faster way? I can't believe it :(
Replies: 2 comments 1 reply
-
In general, the np.asarray approach is preferred because it returns a view of the JAX array (provided the dtypes match and the JAX array is on the CPU), which is expected to be an almost instant operation. What is the size of these arrays, and how much RAM is available on your system? Also, what is the dtype of the JAX arrays?
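A minimal sketch of that approach, assuming the optimizer results are collected into a single JAX array (results is a hypothetical name; the original code is not shown): convert the whole array once with np.asarray, then use .tolist() if plain Python floats are really needed, instead of calling float() on each JAX element one at a time.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the optimizer output: a float32 JAX array.
results = jnp.linspace(0.0, 1.0, 100, dtype=jnp.float32)

# Convert the whole array in one step. On CPU with a matching dtype this
# can avoid a copy; on an accelerator it is one device-to-host transfer.
host = np.asarray(results)

# If Python floats are needed, convert the NumPy array in bulk rather than
# indexing the JAX array element by element (each JAX indexing operation
# is dispatched separately).
as_floats = host.tolist()

print(host.dtype, type(as_floats[0]))
```

Per-element conversion of a JAX array pays a dispatch and transfer cost for every value, which is usually what makes a loop of float() calls slow.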
-
Please read https://docs.jax.dev/en/latest/async_dispatch.html
Are you sure you are timing the conversion itself, as opposed to the computation? Try adding jax.block_until_ready(test_1) before calling np.asarray.
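A minimal sketch of that timing pattern, assuming a hypothetical jitted function hypothetical_optimize as a stand-in for the original optimizer (the original code is not shown). Because JAX dispatches work asynchronously, a timer stopped right after the call may not include the computation, and the remaining wait then gets charged to np.asarray.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def hypothetical_optimize(x):
    # Hypothetical stand-in for the real optimization: a few matmuls.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x.sum(axis=1)


x = jnp.ones((500, 500), dtype=jnp.float32)
hypothetical_optimize(x).block_until_ready()  # warm-up: exclude compile time

start = time.perf_counter()
test_1 = hypothetical_optimize(x)  # may return before the work is finished
jax.block_until_ready(test_1)      # wait for the computation to complete
compute_s = time.perf_counter() - start

start = time.perf_counter()
host = np.asarray(test_1)          # now this measures only the conversion
convert_s = time.perf_counter() - start

print(f"compute: {compute_s:.4f} s, conversion: {convert_s:.6f} s")
```

Without the jax.block_until_ready call, the second timer would absorb whatever part of the computation was still running.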