Skip to content

Commit 05d0119

Browse files
simplify init
1 parent f84cc1e commit 05d0119

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from keras.src.backend.jax.core import IS_THREAD_SAFE
1212
from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS
1313
from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
14-
15-
if is_nnx_enabled():
16-
from keras.src.backend.jax.core import NnxVariable as Variable
17-
else:
18-
from keras.src.backend.jax.core import JaxVariable as Variable
14+
from keras.src.backend.jax.core import Variable
1915
from keras.src.backend.jax.core import cast
2016
from keras.src.backend.jax.core import compute_output_spec
2117
from keras.src.backend.jax.core import cond

keras/src/backend/jax/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __jax_array__(self):
5858
return self.value
5959

6060

61-
_JAX_VARIABLE_TYPE = JaxVariable
61+
Variable = JaxVariable
6262
if config.is_nnx_enabled():
6363
from flax import nnx
6464

@@ -231,7 +231,7 @@ def value(self):
231231
)
232232
return self._maybe_autocast(current_value)
233233

234-
_JAX_VARIABLE_TYPE = NnxVariable
234+
Variable = NnxVariable
235235

236236

237237
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
@@ -247,7 +247,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
247247
# an existing distributed jax array will raise error.
248248
return x
249249

250-
if isinstance(x, _JAX_VARIABLE_TYPE):
250+
if isinstance(x, Variable):
251251
if dtype is not None and x.dtype != dtype:
252252
return x.value.astype(dtype)
253253
return x.value
@@ -531,7 +531,7 @@ def fori_loop(lower, upper, body_fun, init_val):
531531

532532

533533
def stop_gradient(variable):
534-
if isinstance(variable, _JAX_VARIABLE_TYPE):
534+
if isinstance(variable, Variable):
535535
variable = variable.value
536536
return jax.lax.stop_gradient(variable)
537537

0 commit comments

Comments
 (0)