Skip to content

Commit a4e43ce

Browse files
Merge pull request #78 from brandonwillard/fix-tf-lvar-op-name
Return lvar instead of str when TF meta Operator name is an lvar
2 parents 9b0f72e + cbf2551 commit a4e43ce

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def name(self):
499499
if isvar(self.node_def):
500500
self._name = var()
501501
else:
502-
self._name = str(self.node_def.name)
502+
self._name = self.node_def.name
503503

504504
return self._name
505505

@@ -606,6 +606,8 @@ class TFlowMetaTensor(TFlowMetaSymbol, MetaVariable):
606606

607607
def __init__(self, op, value_index, dtype, obj=None):
608608
self.op = metatize(op)
609+
# TODO: Sync this value with `op.node_def.attr['dtype']` and/or
610+
# `op.node_def.attr['T']`?
609611
self.dtype = dtype
610612
self.value_index = value_index
611613
super().__init__(obj=obj)

tests/tensorflow/test_meta.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ def test_meta_basic():
187187
assert log_mt.dtype == tf.float32
188188
assert log_mt.name == 'Log:0'
189189

190+
log_mt = mt.log(var(), name=var())
191+
assert isvar(log_mt.name)
192+
190193

191194
@pytest.mark.usefixtures("run_with_tensorflow")
192195
@run_in_graph_mode
@@ -201,6 +204,8 @@ def test_meta_Op():
201204
# reasonable. Likewise, the `0` gets converted, but it probably shouldn't be.
202205
test_op = TFlowMetaOp(mt.Concat, var(), [[t1_tf, t2_tf], 0])
203206

207+
assert isvar(test_op.name)
208+
204209
# Make sure we converted lists to tuples
205210
assert isinstance(test_op.inputs, tuple)
206211
assert isinstance(test_op.inputs[0], tuple)

0 commit comments

Comments
 (0)