Skip to content

Commit dff6ff4

Browse files
Merge pull request #79 from brandonwillard/overload-tf-arithmetic-methods
Add overloaded Python arithmetic methods to meta TF tensor
2 parents a4e43ce + 229c075 commit dff6ff4

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

symbolic_pymc/tensorflow/meta.py

+37
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,43 @@ def reify(self):
690690

691691
return self
692692

693+
def __truediv__(self, y):
694+
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
695+
return mt.realdiv(self, y, name="truediv")
696+
697+
def __rtruediv__(self, x):
698+
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
699+
return mt.realdiv(x, self, name="truediv")
700+
701+
def __add__(self, y):
702+
# TODO: If `self.dtype == tf.dtypes.string`, use `mt.add`
703+
return mt.addv2(self, y, name="add")
704+
705+
def __radd__(self, x):
706+
# TODO: If `x.dtype == tf.dtypes.string`, use `mt.add`
707+
return mt.addv2(x, self, name="add")
708+
709+
def __sub__(self, y):
710+
return mt.sub(self, y, name="sub")
711+
712+
def __rsub__(self, x):
713+
return mt.sub(x, self, name="sub")
714+
715+
def __mul__(self, y):
716+
return mt.mul(self, y, name="mul")
717+
718+
def __rmul__(self, x):
719+
return mt.mul(x, self, name="mul")
720+
721+
def __abs__(self):
722+
return mt.abs(self, name="Abs")
723+
724+
def __pow__(self, y):
725+
return mt.pow(self, y, name="pow")
726+
727+
def __neg__(self):
728+
return mt.neg(self, name="Neg")
729+
693730

694731
class TFlowMetaTensorShape(TFlowMetaSymbol):
695732
base = tf.TensorShape

tests/tensorflow/test_meta.py

+59
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,62 @@ def test_opdef_func():
456456

457457
with tf.compat.v1.Session() as sess:
458458
assert sum_tf.eval() == np.r_[3]
459+
460+
461+
@pytest.mark.usefixtures("run_with_tensorflow")
462+
@run_in_graph_mode
463+
def test_tensor_ops():
464+
465+
with tf.Graph().as_default():
466+
x_tf = tf.compat.v1.placeholder('float')
467+
y_tf = tf.compat.v1.placeholder('float')
468+
469+
mul_tf = x_tf * y_tf
470+
rmul_tf = 1.0 * x_tf
471+
div_tf = x_tf / y_tf
472+
rdiv_tf = 1.0 / y_tf
473+
add_tf = x_tf + y_tf
474+
radd_tf = 1.0 + y_tf
475+
sub_tf = x_tf - y_tf
476+
rsub_tf = 1.0 - y_tf
477+
pow_tf = x_tf**y_tf
478+
neg_tf = -x_tf
479+
abs_tf = abs(x_tf)
480+
481+
with tf.Graph().as_default():
482+
x_mt = mt.Placeholder('float')
483+
y_mt = mt.Placeholder('float')
484+
485+
mul_mt = x_mt * y_mt
486+
assert mul_mt.name == mul_tf.name
487+
assert mul_mt.op.type == mul_tf.op.type
488+
rmul_mt = 1.0 * x_mt
489+
assert rmul_mt.name == rmul_tf.name
490+
assert rmul_mt.op.type == rmul_tf.op.type
491+
div_mt = x_mt / y_mt
492+
assert div_mt.name == div_tf.name
493+
assert div_mt.op.type == div_tf.op.type
494+
rdiv_mt = 1.0 / y_mt
495+
assert rdiv_mt.name == rdiv_tf.name
496+
assert rdiv_mt.op.type == rdiv_tf.op.type
497+
add_mt = x_mt + y_mt
498+
assert add_mt.name == add_tf.name
499+
assert add_mt.op.type == add_tf.op.type
500+
radd_mt = 1.0 + y_mt
501+
assert radd_mt.name == radd_tf.name
502+
assert radd_mt.op.type == radd_tf.op.type
503+
sub_mt = x_mt - y_mt
504+
assert sub_mt.name == sub_tf.name
505+
assert sub_mt.op.type == sub_tf.op.type
506+
rsub_mt = 1.0 - y_mt
507+
assert rsub_mt.name == rsub_tf.name
508+
assert rsub_mt.op.type == rsub_tf.op.type
509+
pow_mt = x_mt**y_mt
510+
assert pow_mt.name == pow_tf.name
511+
assert pow_mt.op.type == pow_tf.op.type
512+
neg_mt = -x_mt
513+
assert neg_mt.name == neg_tf.name
514+
assert neg_mt.op.type == neg_tf.op.type
515+
abs_mt = abs(x_mt)
516+
assert abs_mt.name == abs_tf.name
517+
assert abs_mt.op.type == abs_tf.op.type

0 commit comments

Comments
 (0)