Skip to content

Commit 8bf6a58

Browse files
authored
Add VectorizedMap op class. (#21516)
Also: - fix `Map.compute_output_spec` so that it handles nested inputs - test `map` op with nested inputs - added missing `Deg2Rad.compute_output_spec` - added test verifying that all ops implement `compute_output_spec`.
1 parent 7cb0e48 commit 8bf6a58

File tree

4 files changed

+124
-15
lines changed

4 files changed

+124
-15
lines changed

keras/src/ops/core.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@ def call(self, f, xs):
1616
return backend.core.map(f, xs)
1717

1818
def compute_output_spec(self, f, xs):
19-
x = xs[0]
20-
n = xs.shape[0]
19+
x = tree.map_structure(lambda t: t[0], xs)
20+
n = tree.flatten(xs)[0].shape[0]
2121
y = backend.compute_output_spec(f, x)
2222

23-
def append_batch_axis(x):
23+
def append_batch_axis(t):
2424
return KerasTensor(
25-
shape=(n,) + x.shape, dtype=x.dtype, sparse=x.sparse
25+
shape=(n,) + t.shape,
26+
dtype=t.dtype,
27+
sparse=t.sparse,
28+
ragged=t.ragged,
2629
)
2730

2831
y = tree.map_structure(append_batch_axis, y)
@@ -1078,7 +1081,31 @@ def cond(pred, true_fn, false_fn):
10781081
return Cond()(pred, true_fn, false_fn)
10791082

10801083

1081-
# TODO: also create an Op subclass VectorizedMap.
1084+
class VectorizedMap(Operation):
1085+
def __init__(self, function, *, name=None):
1086+
super().__init__(name=name)
1087+
self.function = function
1088+
1089+
def call(self, elements):
1090+
return backend.core.vectorized_map(self.function, elements)
1091+
1092+
def compute_output_spec(self, elements):
1093+
x = tree.map_structure(lambda t: t[0], elements)
1094+
n = tree.flatten(elements)[0].shape[0]
1095+
y = backend.compute_output_spec(self.function, x)
1096+
1097+
def append_batch_axis(t):
1098+
return KerasTensor(
1099+
shape=(n,) + t.shape,
1100+
dtype=t.dtype,
1101+
sparse=t.sparse,
1102+
ragged=t.ragged,
1103+
)
1104+
1105+
y = tree.map_structure(append_batch_axis, y)
1106+
return y
1107+
1108+
10821109
@keras_export("keras.ops.vectorized_map")
10831110
def vectorized_map(function, elements):
10841111
"""Parallel map of `function` on axis 0 of tensor(s) `elements`.
@@ -1109,6 +1136,8 @@ def vectorized_map(function, elements):
11091136
In this case, `function` is expected to take as input
11101137
a single list of tensor arguments.
11111138
"""
1139+
if any_symbolic_tensors((elements,)):
1140+
return VectorizedMap(function)(elements)
11121141
return backend.core.vectorized_map(function, elements)
11131142

11141143

keras/src/ops/core_test.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,24 @@ def test_map(self):
5656
def f(x):
5757
return x**2
5858

59-
xs = KerasTensor((None,))
60-
self.assertEqual(core.map(f, xs).shape, (None,))
59+
xs = KerasTensor((None, 5))
60+
self.assertEqual(core.map(f, xs).shape, (None, 5))
6161

6262
# Test nested output
6363
def f2(x):
6464
return {"a": x**2, "b": x * 10}
6565

66-
xs = KerasTensor((None,))
66+
xs = KerasTensor((None, 5))
6767
ys = core.map(f2, xs)
68-
self.assertEqual(ys["a"].shape, (None,))
69-
self.assertEqual(ys["b"].shape, (None,))
68+
self.assertEqual(ys["a"].shape, (None, 5))
69+
self.assertEqual(ys["b"].shape, (None, 5))
70+
71+
# Test nested input
72+
def f3(x):
73+
return x[0] + x[1]
74+
75+
xs = (KerasTensor((None, 5)), KerasTensor((None, 5)))
76+
self.assertEqual(core.map(f3, xs).shape, (None, 5))
7077

7178
def test_saturate_cast(self):
7279
x = KerasTensor((3, 5, None), dtype="float32")
@@ -125,6 +132,29 @@ def fn(x, y):
125132
self.assertEqual(result[0].shape, (None,))
126133
self.assertEqual(result[1].shape, (None,))
127134

135+
def test_vectorized_map(self):
136+
def f(x):
137+
return x**2
138+
139+
xs = KerasTensor((None, 5))
140+
self.assertEqual(core.vectorized_map(f, xs).shape, (None, 5))
141+
142+
# Test nested output
143+
def f2(x):
144+
return {"a": x**2, "b": x * 10}
145+
146+
xs = KerasTensor((None, 5))
147+
ys = core.vectorized_map(f2, xs)
148+
self.assertEqual(ys["a"].shape, (None, 5))
149+
self.assertEqual(ys["b"].shape, (None, 5))
150+
151+
# Test nested input
152+
def f3(x):
153+
return x[0] + x[1]
154+
155+
xs = (KerasTensor((None, 5)), KerasTensor((None, 5)))
156+
self.assertEqual(core.vectorized_map(f3, xs).shape, (None, 5))
157+
128158
def test_while_loop(self):
129159
def cond(args):
130160
return tree.flatten(args)[0] < 10
@@ -203,18 +233,25 @@ def test_map(self):
203233
def f(x):
204234
return x**2
205235

206-
xs = KerasTensor((6,))
236+
xs = KerasTensor((6, 5))
207237
ys = core.map(f, xs)
208-
self.assertEqual(ys.shape, (6,))
238+
self.assertEqual(ys.shape, (6, 5))
209239

210240
# Test nested output
211241
def f2(x):
212242
return {"a": x**2, "b": x * 10}
213243

214-
xs = KerasTensor((6,))
244+
xs = KerasTensor((6, 5))
215245
ys = core.map(f2, xs)
216-
self.assertEqual(ys["a"].shape, (6,))
217-
self.assertEqual(ys["b"].shape, (6,))
246+
self.assertEqual(ys["a"].shape, (6, 5))
247+
self.assertEqual(ys["b"].shape, (6, 5))
248+
249+
# Test nested input
250+
def f3(x):
251+
return x[0] + x[1]
252+
253+
xs = (KerasTensor((6, 5)), KerasTensor((6, 5)))
254+
self.assertEqual(core.map(f3, xs).shape, (6, 5))
218255

219256
def test_saturate_cast(self):
220257
x = KerasTensor((3, 5, 7), dtype="float32")
@@ -307,6 +344,30 @@ def fn(x, y):
307344
self.assertEqual(core.switch(index, [fn], x, y)[0].shape, (5,))
308345
self.assertEqual(core.switch(index, [fn], x, y)[1].shape, (2,))
309346

347+
def test_vectorized_map(self):
348+
def f(x):
349+
return x**2
350+
351+
xs = KerasTensor((6, 5))
352+
ys = core.vectorized_map(f, xs)
353+
self.assertEqual(ys.shape, (6, 5))
354+
355+
# Test nested output
356+
def f2(x):
357+
return {"a": x**2, "b": x * 10}
358+
359+
xs = KerasTensor((6, 5))
360+
ys = core.vectorized_map(f2, xs)
361+
self.assertEqual(ys["a"].shape, (6, 5))
362+
self.assertEqual(ys["b"].shape, (6, 5))
363+
364+
# Test nested input
365+
def f3(x):
366+
return x[0] + x[1]
367+
368+
xs = (KerasTensor((6, 5)), KerasTensor((6, 5)))
369+
self.assertEqual(core.vectorized_map(f3, xs).shape, (6, 5))
370+
310371
def test_while_loop(self):
311372
def cond(args):
312373
return tree.flatten(args)[0] < 10

keras/src/ops/numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2340,6 +2340,14 @@ class Deg2rad(Operation):
23402340
def call(self, x):
23412341
return backend.numpy.deg2rad(x)
23422342

2343+
def compute_output_spec(self, x):
2344+
dtype = backend.standardize_dtype(x.dtype)
2345+
if dtype in ["int64", "float64"]:
2346+
dtype = "float64"
2347+
elif dtype not in ["bfloat16", "float16"]:
2348+
dtype = backend.floatx()
2349+
return KerasTensor(x.shape, dtype)
2350+
23432351

23442352
@keras_export(["keras.ops.deg2rad", "keras.ops.numpy.deg2rad"])
23452353
def deg2rad(x):

keras/src/ops/ops_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def test_class_function_consistency(self, module_name):
182182
# Check order of parameters.
183183
if name in (
184184
"fori_loop",
185+
"vectorized_map",
185186
"while_loop",
186187
"batch_normalization",
187188
"dot_product_attention",
@@ -224,6 +225,16 @@ def test_class_function_consistency(self, module_name):
224225
f"function `{name}` and op class `{op_class.__name__}`",
225226
)
226227

228+
# ==== Check compute_output_spec is implement ====
229+
# - op class should override Operation's `compute_output_spec`
230+
self.assertTrue(
231+
hasattr(op_class, "compute_output_spec")
232+
and op_class.compute_output_spec
233+
is not Operation.compute_output_spec,
234+
f"Op class `{op_class.__name__}` should override "
235+
"`compute_output_spec`",
236+
)
237+
227238
@parameterized.named_parameters(named_product(module_name=OPS_MODULES))
228239
def test_backend_consistency(self, module_name):
229240
ops_module = getattr(ops, module_name)

0 commit comments

Comments
 (0)