diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py
index f2ac3bafaf2..d86ca94f675 100644
--- a/keras/src/ops/core.py
+++ b/keras/src/ops/core.py
@@ -16,13 +16,16 @@ def call(self, f, xs):
         return backend.core.map(f, xs)
 
     def compute_output_spec(self, f, xs):
-        x = xs[0]
-        n = xs.shape[0]
+        x = tree.map_structure(lambda t: t[0], xs)
+        n = tree.flatten(xs)[0].shape[0]
         y = backend.compute_output_spec(f, x)
 
-        def append_batch_axis(x):
+        def append_batch_axis(t):
             return KerasTensor(
-                shape=(n,) + x.shape, dtype=x.dtype, sparse=x.sparse
+                shape=(n,) + t.shape,
+                dtype=t.dtype,
+                sparse=t.sparse,
+                ragged=t.ragged,
             )
 
         y = tree.map_structure(append_batch_axis, y)
@@ -1078,7 +1081,31 @@ def cond(pred, true_fn, false_fn):
     return Cond()(pred, true_fn, false_fn)
 
 
-# TODO: also create an Op subclass VectorizedMap.
+class VectorizedMap(Operation):
+    def __init__(self, function, *, name=None):
+        super().__init__(name=name)
+        self.function = function
+
+    def call(self, elements):
+        return backend.core.vectorized_map(self.function, elements)
+
+    def compute_output_spec(self, elements):
+        x = tree.map_structure(lambda t: t[0], elements)
+        n = tree.flatten(elements)[0].shape[0]
+        y = backend.compute_output_spec(self.function, x)
+
+        def append_batch_axis(t):
+            return KerasTensor(
+                shape=(n,) + t.shape,
+                dtype=t.dtype,
+                sparse=t.sparse,
+                ragged=t.ragged,
+            )
+
+        y = tree.map_structure(append_batch_axis, y)
+        return y
+
+
 @keras_export("keras.ops.vectorized_map")
 def vectorized_map(function, elements):
     """Parallel map of `function` on axis 0 of tensor(s) `elements`.
@@ -1109,6 +1136,8 @@ def vectorized_map(function, elements):
     In this case, `function` is expected to take as input a
     single list of tensor arguments.
     """
+    if any_symbolic_tensors((elements,)):
+        return VectorizedMap(function)(elements)
     return backend.core.vectorized_map(function, elements)
diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py
index 2d4270ca818..820c2c73f3e 100644
--- a/keras/src/ops/core_test.py
+++ b/keras/src/ops/core_test.py
@@ -56,17 +56,24 @@ def test_map(self):
         def f(x):
             return x**2
 
-        xs = KerasTensor((None,))
-        self.assertEqual(core.map(f, xs).shape, (None,))
+        xs = KerasTensor((None, 5))
+        self.assertEqual(core.map(f, xs).shape, (None, 5))
 
         # Test nested output
         def f2(x):
             return {"a": x**2, "b": x * 10}
 
-        xs = KerasTensor((None,))
+        xs = KerasTensor((None, 5))
         ys = core.map(f2, xs)
-        self.assertEqual(ys["a"].shape, (None,))
-        self.assertEqual(ys["b"].shape, (None,))
+        self.assertEqual(ys["a"].shape, (None, 5))
+        self.assertEqual(ys["b"].shape, (None, 5))
+
+        # Test nested input
+        def f3(x):
+            return x[0] + x[1]
+
+        xs = (KerasTensor((None, 5)), KerasTensor((None, 5)))
+        self.assertEqual(core.map(f3, xs).shape, (None, 5))
 
     def test_saturate_cast(self):
         x = KerasTensor((3, 5, None), dtype="float32")
@@ -125,6 +132,29 @@ def fn(x, y):
         self.assertEqual(result[0].shape, (None,))
         self.assertEqual(result[1].shape, (None,))
 
+    def test_vectorized_map(self):
+        def f(x):
+            return x**2
+
+        xs = KerasTensor((None, 5))
+        self.assertEqual(core.vectorized_map(f, xs).shape, (None, 5))
+
+        # Test nested output
+        def f2(x):
+            return {"a": x**2, "b": x * 10}
+
+        xs = KerasTensor((None, 5))
+        ys = core.vectorized_map(f2, xs)
+        self.assertEqual(ys["a"].shape, (None, 5))
+        self.assertEqual(ys["b"].shape, (None, 5))
+
+        # Test nested input
+        def f3(x):
+            return x[0] + x[1]
+
+        xs = (KerasTensor((None, 5)), KerasTensor((None, 5)))
+        self.assertEqual(core.vectorized_map(f3, xs).shape, (None, 5))
+
     def test_while_loop(self):
         def cond(args):
             return tree.flatten(args)[0] < 10
@@ -203,18 +233,25 @@ def test_map(self):
         def f(x):
             return x**2
 
-        xs = KerasTensor((6,))
+        xs = KerasTensor((6, 5))
         ys = core.map(f, xs)
-        self.assertEqual(ys.shape, (6,))
+        self.assertEqual(ys.shape, (6, 5))
 
         # Test nested output
         def f2(x):
             return {"a": x**2, "b": x * 10}
 
-        xs = KerasTensor((6,))
+        xs = KerasTensor((6, 5))
         ys = core.map(f2, xs)
-        self.assertEqual(ys["a"].shape, (6,))
-        self.assertEqual(ys["b"].shape, (6,))
+        self.assertEqual(ys["a"].shape, (6, 5))
+        self.assertEqual(ys["b"].shape, (6, 5))
+
+        # Test nested input
+        def f3(x):
+            return x[0] + x[1]
+
+        xs = (KerasTensor((6, 5)), KerasTensor((6, 5)))
+        self.assertEqual(core.map(f3, xs).shape, (6, 5))
 
     def test_saturate_cast(self):
         x = KerasTensor((3, 5, 7), dtype="float32")
@@ -307,6 +344,30 @@ def fn(x, y):
         self.assertEqual(core.switch(index, [fn], x, y)[0].shape, (5,))
         self.assertEqual(core.switch(index, [fn], x, y)[1].shape, (2,))
 
+    def test_vectorized_map(self):
+        def f(x):
+            return x**2
+
+        xs = KerasTensor((6, 5))
+        ys = core.vectorized_map(f, xs)
+        self.assertEqual(ys.shape, (6, 5))
+
+        # Test nested output
+        def f2(x):
+            return {"a": x**2, "b": x * 10}
+
+        xs = KerasTensor((6, 5))
+        ys = core.vectorized_map(f2, xs)
+        self.assertEqual(ys["a"].shape, (6, 5))
+        self.assertEqual(ys["b"].shape, (6, 5))
+
+        # Test nested input
+        def f3(x):
+            return x[0] + x[1]
+
+        xs = (KerasTensor((6, 5)), KerasTensor((6, 5)))
+        self.assertEqual(core.vectorized_map(f3, xs).shape, (6, 5))
+
     def test_while_loop(self):
         def cond(args):
             return tree.flatten(args)[0] < 10
diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py
index 08cfc73644c..9fe5f3fcafb 100644
--- a/keras/src/ops/numpy.py
+++ b/keras/src/ops/numpy.py
@@ -2340,6 +2340,14 @@ class Deg2rad(Operation):
     def call(self, x):
         return backend.numpy.deg2rad(x)
 
+    def compute_output_spec(self, x):
+        dtype = backend.standardize_dtype(x.dtype)
+        if dtype in ["int64", "float64"]:
+            dtype = "float64"
+        elif dtype not in ["bfloat16", "float16"]:
+            dtype = backend.floatx()
+        return KerasTensor(x.shape, dtype)
+
 
 @keras_export(["keras.ops.deg2rad", "keras.ops.numpy.deg2rad"])
 def deg2rad(x):
diff --git a/keras/src/ops/ops_test.py b/keras/src/ops/ops_test.py
index a97567b67cf..f51510cd57c 100644
--- a/keras/src/ops/ops_test.py
+++ b/keras/src/ops/ops_test.py
@@ -182,6 +182,7 @@ def test_class_function_consistency(self, module_name):
         # Check order of parameters.
         if name in (
             "fori_loop",
+            "vectorized_map",
             "while_loop",
             "batch_normalization",
             "dot_product_attention",
@@ -224,6 +225,16 @@ def test_class_function_consistency(self, module_name):
                 f"function `{name}` and op class `{op_class.__name__}`",
             )
 
+        # ==== Check compute_output_spec is implemented ====
+        # - op class should override Operation's `compute_output_spec`
+        self.assertTrue(
+            hasattr(op_class, "compute_output_spec")
+            and op_class.compute_output_spec
+            is not Operation.compute_output_spec,
+            f"Op class `{op_class.__name__}` should override "
+            "`compute_output_spec`",
+        )
+
     @parameterized.named_parameters(named_product(module_name=OPS_MODULES))
     def test_backend_consistency(self, module_name):
         ops_module = getattr(ops, module_name)
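
A minimal usage sketch of the behavior this diff enables, assuming only the public `keras.KerasTensor` and `keras.ops.vectorized_map` APIs. With the new `VectorizedMap` op class, symbolic tensors are routed through `compute_output_spec` instead of the backend call, so output shapes can be traced without concrete data:

    from keras import KerasTensor, ops

    # Symbolic input: dispatches to VectorizedMap.compute_output_spec
    # rather than backend.core.vectorized_map.
    xs = KerasTensor((None, 5))
    ys = ops.vectorized_map(lambda x: x**2, xs)
    print(ys.shape)  # (None, 5)

    # Nested (tuple) input: the batch size is read from the first
    # flattened element, mirroring the updated Map.compute_output_spec.
    pair = (KerasTensor((8, 3)), KerasTensor((8, 3)))
    zs = ops.vectorized_map(lambda t: t[0] + t[1], pair)
    print(zs.shape)  # (8, 3)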