Skip to content

Commit 9b0f72e

Browse files
Merge pull request #77 from brandonwillard/fix-tf-name-and-dtype-inference
Fix TF tensor name construction and add NodeDef dtype information
2 parents 8077d64 + d644e0b commit 9b0f72e

File tree

2 files changed

+42
-28
lines changed

2 files changed

+42
-28
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from inspect import Parameter, Signature
1010

1111
from collections import OrderedDict
12-
from collections.abc import Sequence
1312

1413
from functools import partial
1514

@@ -220,11 +219,11 @@ def __init__(self, obj=None):
220219
super().__init__(obj=obj)
221220
self._apply_func_sig, self._apply_func = op_def_lib.get_op_info(obj)
222221

223-
def out_meta_types(self, inputs=None):
222+
def out_meta_types(self, inputs=None, node_def=None):
224223
def _convert_outputs(o):
225-
if o.type_attr == "T":
226-
return (TFlowMetaTensor, var())
227-
elif o.type_attr == "dtype":
224+
if o.type_attr == "T" and node_def:
225+
return (TFlowMetaTensor, node_def.attr.get("T", var()))
226+
elif o.type_attr == "dtype" and inputs:
228227
return (TFlowMetaTensor, inputs.get("dtype", var()))
229228
else:
230229
return (TFlowMetaTensor, var())
@@ -284,7 +283,6 @@ def __call__(self, *args, **kwargs):
284283
apply_arguments.get(i.name) for i in self.obj.input_arg if i.name in apply_arguments
285284
)
286285

287-
# Get the `OpDef`-instantiating parameters and call them a "node_def".
288286
node_attr = {a.name: apply_arguments.get(a.name, a) for a in self.obj.attr}
289287

290288
op_name = op_kwargs.get("name", self.obj.name)
@@ -346,6 +344,8 @@ def _protobuf_convert(cls, k, v):
346344
return metatize(tensor_shape.as_shape(v.shape))
347345
elif k == "dtype":
348346
return tf.as_dtype(v.type).name
347+
elif k == "T":
348+
return tf.as_dtype(v.type).name
349349
elif k == "value":
350350
return tensor_util.MakeNdarray(v.tensor)
351351
else:
@@ -364,22 +364,17 @@ def __init__(self, op, name, attr, obj=None):
364364
self.name = name if isvar(name) else str(name)
365365

366366
if not isvar(attr):
367-
# We want to limit the attributes we'll consider to those that show
368-
# up in an OpDef function's signature (e.g. ignore info about
369-
# permissible types).
370367
opdef_sig, _ = op_def_lib.get_op_info(self.op)
371-
op_param_names = opdef_sig.parameters.keys()
372-
373368
_attr = dict()
369+
374370
for k, v in attr.items():
375371
if isinstance(v, Message):
376372
try:
377373
v = self._protobuf_convert(k, v)
378374
except TypeError:
379-
continue
375+
v = var()
380376

381-
if k != "T" and k in op_param_names:
382-
_attr[k] = v
377+
_attr[k] = v
383378

384379
self.attr = _attr
385380
else:
@@ -532,11 +527,12 @@ def outputs(self):
532527
else:
533528

534529
apply_arguments = self.op_def.input_args(*self.inputs, **self.node_def.attr)
535-
out_types_mt = self.op_def.out_meta_types(inputs=apply_arguments)
530+
out_types_mt = self.op_def.out_meta_types(
531+
inputs=apply_arguments, node_def=self.node_def
532+
)
536533

537534
mt_outs = tuple(
538-
o_type(self, i, var() if o_dtype is None else o_dtype)
539-
for i, (o_type, o_dtype) in enumerate(out_types_mt)
535+
o_type(self, i, o_dtype) for i, (o_type, o_dtype) in enumerate(out_types_mt)
540536
)
541537

542538
self._outputs = mt_outs
@@ -574,7 +570,15 @@ def reify(self):
574570
if isvar(self.node_def):
575571
return self
576572

577-
op_attrs, op_attrs_unreified = meta_reify_iter(self.node_def.attr)
573+
op_attrs, op_attrs_unreified = meta_reify_iter(
574+
# Only use NodeDef attrs that appear in the OpDef's call signature.
575+
# Other NodeDef attrs, like dtype and shape, can be computed.
576+
{
577+
k: v
578+
for k, v in self.node_def.attr.items()
579+
if k in self.op_def._apply_func_sig.parameters
580+
}
581+
)
578582

579583
if not (op_inputs_unreified or op_attrs_unreified or MetaSymbol.is_meta(self.name)):
580584

@@ -587,6 +591,8 @@ def reify(self):
587591
tf_out = self.op_def._apply_func(**apply_arguments)
588592
op_tf = tf_out.op
589593

594+
# TODO: Update NodeDef attrs?
595+
590596
assert op_tf is not None
591597
self._obj = op_tf
592598
return self.obj
@@ -623,14 +629,8 @@ def name(self):
623629

624630
if self.obj is not None and not isinstance(self.obj, Var):
625631
name = self.obj.name
626-
elif (
627-
self.op is not None
628-
and not isvar(self.op)
629-
and not isvar(self.op.name)
630-
and not isinstance(self.op.outputs, Sequence)
631-
):
632-
out_num = self.op.outputs.index(self)
633-
name = f"{self.op.name}:{out_num}"
632+
elif isinstance(getattr(self.op, "name", None), str) and not isvar(self.value_index):
633+
name = f"{self.op.name}:{self.value_index}"
634634
else:
635635
name = var()
636636

tests/tensorflow/test_meta.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
If you're debugging/running tests manually, it might help to simply
33
disable eager execution entirely:
44
5-
> tf.compat.v1.disable_eager_execution()
5+
tf.compat.v1.disable_eager_execution()
66
"""
77
import pytest
88
import numpy as np
@@ -172,6 +172,21 @@ def test_meta_basic():
172172
assert a_mt.shape.ndims == 2
173173
assert a_mt.shape == TFlowMetaTensorShape([1, 2])
174174

175+
# Make sure that names are properly inferred when there are no base objects
176+
# to reference
177+
with tf.Graph().as_default():
178+
one_mt = mt(1.0)
179+
log_mt = mt.log(one_mt)
180+
assert log_mt.name == 'Log:0'
181+
assert log_mt.dtype == tf.float32
182+
assert log_mt.op.outputs[0].dtype == tf.float32
183+
184+
log_mt._name = None
185+
one_mt._obj = None
186+
log_mt._obj = None
187+
assert log_mt.dtype == tf.float32
188+
assert log_mt.name == 'Log:0'
189+
175190

176191
@pytest.mark.usefixtures("run_with_tensorflow")
177192
@run_in_graph_mode
@@ -394,7 +409,6 @@ def test_nodedef():
394409

395410
assert 'compute_uv' in node_def_mt.attr
396411
assert 'full_matrices' in node_def_mt.attr
397-
assert 'T' not in node_def_mt.attr
398412

399413
# Some outputs use nodedef information; let's test those.
400414
norm_rv = mt.RandomStandardNormal(mean=0, stddev=1, shape=(1000,), dtype=tf.float32, name=var())

0 commit comments

Comments
 (0)