@@ -58,7 +58,7 @@ def __jax_array__(self):
         return self.value
 
 
-_JAX_VARIABLE_TYPE = JaxVariable
+Variable = JaxVariable
 if config.is_nnx_enabled():
     from flax import nnx
 
@@ -231,7 +231,7 @@ def value(self):
                 )
             return self._maybe_autocast(current_value)
 
-    _JAX_VARIABLE_TYPE = NnxVariable
+    Variable = NnxVariable
 
 
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
@@ -247,7 +247,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
         # an existing distributed jax array will raise error.
         return x
 
-    if isinstance(x, _JAX_VARIABLE_TYPE):
+    if isinstance(x, Variable):
         if dtype is not None and x.dtype != dtype:
             return x.value.astype(dtype)
         return x.value
@@ -531,7 +531,7 @@ def fori_loop(lower, upper, body_fun, init_val):
 
 
 def stop_gradient(variable):
-    if isinstance(variable, _JAX_VARIABLE_TYPE):
+    if isinstance(variable, Variable):
         variable = variable.value
     return jax.lax.stop_gradient(variable)
 
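For illustration, here is a minimal, self-contained sketch of the pattern these hunks rely on: a module-level `Variable` alias bound to the active variable class, with helpers unwrapping instances via `isinstance()` before calling into JAX. The `JaxVariable` stub below is a simplified stand-in, not the real Keras class, and the NNX rebinding is omitted.

```python
# Sketch only: stand-in for the backend variable class, not the Keras one.
import jax
import jax.numpy as jnp


class JaxVariable:
    def __init__(self, value):
        self.value = jnp.asarray(value)


# In the real module this alias is rebound to NnxVariable when NNX is
# enabled; the sketch keeps only the default binding.
Variable = JaxVariable


def stop_gradient(variable):
    # Unwrap a backend variable to its underlying jax.Array first.
    if isinstance(variable, Variable):
        variable = variable.value
    return jax.lax.stop_gradient(variable)


print(stop_gradient(Variable([1.0, 2.0])))  # -> [1. 2.]
```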