9
9
from inspect import Parameter , Signature
10
10
11
11
from collections import OrderedDict
12
- from collections .abc import Sequence
13
12
14
13
from functools import partial
15
14
@@ -220,11 +219,11 @@ def __init__(self, obj=None):
220
219
super ().__init__ (obj = obj )
221
220
self ._apply_func_sig , self ._apply_func = op_def_lib .get_op_info (obj )
222
221
223
- def out_meta_types (self , inputs = None ):
222
+ def out_meta_types (self , inputs = None , node_def = None ):
224
223
def _convert_outputs (o ):
225
- if o .type_attr == "T" :
226
- return (TFlowMetaTensor , var ())
227
- elif o .type_attr == "dtype" :
224
+ if o .type_attr == "T" and node_def :
225
+ return (TFlowMetaTensor , node_def . attr . get ( "T" , var () ))
226
+ elif o .type_attr == "dtype" and inputs :
228
227
return (TFlowMetaTensor , inputs .get ("dtype" , var ()))
229
228
else :
230
229
return (TFlowMetaTensor , var ())
@@ -284,7 +283,6 @@ def __call__(self, *args, **kwargs):
284
283
apply_arguments .get (i .name ) for i in self .obj .input_arg if i .name in apply_arguments
285
284
)
286
285
287
- # Get the `OpDef`-instantiating parameters and call them a "node_def".
288
286
node_attr = {a .name : apply_arguments .get (a .name , a ) for a in self .obj .attr }
289
287
290
288
op_name = op_kwargs .get ("name" , self .obj .name )
@@ -346,6 +344,8 @@ def _protobuf_convert(cls, k, v):
346
344
return metatize (tensor_shape .as_shape (v .shape ))
347
345
elif k == "dtype" :
348
346
return tf .as_dtype (v .type ).name
347
+ elif k == "T" :
348
+ return tf .as_dtype (v .type ).name
349
349
elif k == "value" :
350
350
return tensor_util .MakeNdarray (v .tensor )
351
351
else :
@@ -364,22 +364,17 @@ def __init__(self, op, name, attr, obj=None):
364
364
self .name = name if isvar (name ) else str (name )
365
365
366
366
if not isvar (attr ):
367
- # We want to limit the attributes we'll consider to those that show
368
- # up in an OpDef function's signature (e.g. ignore info about
369
- # permissible types).
370
367
opdef_sig , _ = op_def_lib .get_op_info (self .op )
371
- op_param_names = opdef_sig .parameters .keys ()
372
-
373
368
_attr = dict ()
369
+
374
370
for k , v in attr .items ():
375
371
if isinstance (v , Message ):
376
372
try :
377
373
v = self ._protobuf_convert (k , v )
378
374
except TypeError :
379
- continue
375
+ v = var ()
380
376
381
- if k != "T" and k in op_param_names :
382
- _attr [k ] = v
377
+ _attr [k ] = v
383
378
384
379
self .attr = _attr
385
380
else :
@@ -532,11 +527,12 @@ def outputs(self):
532
527
else :
533
528
534
529
apply_arguments = self .op_def .input_args (* self .inputs , ** self .node_def .attr )
535
- out_types_mt = self .op_def .out_meta_types (inputs = apply_arguments )
530
+ out_types_mt = self .op_def .out_meta_types (
531
+ inputs = apply_arguments , node_def = self .node_def
532
+ )
536
533
537
534
mt_outs = tuple (
538
- o_type (self , i , var () if o_dtype is None else o_dtype )
539
- for i , (o_type , o_dtype ) in enumerate (out_types_mt )
535
+ o_type (self , i , o_dtype ) for i , (o_type , o_dtype ) in enumerate (out_types_mt )
540
536
)
541
537
542
538
self ._outputs = mt_outs
@@ -574,7 +570,15 @@ def reify(self):
574
570
if isvar (self .node_def ):
575
571
return self
576
572
577
- op_attrs , op_attrs_unreified = meta_reify_iter (self .node_def .attr )
573
+ op_attrs , op_attrs_unreified = meta_reify_iter (
574
+ # Only use NodeDef attrs that appear in the OpDef's call signature.
575
+ # Other NodeDef attrs, like dtype and shape, can be computed.
576
+ {
577
+ k : v
578
+ for k , v in self .node_def .attr .items ()
579
+ if k in self .op_def ._apply_func_sig .parameters
580
+ }
581
+ )
578
582
579
583
if not (op_inputs_unreified or op_attrs_unreified or MetaSymbol .is_meta (self .name )):
580
584
@@ -587,6 +591,8 @@ def reify(self):
587
591
tf_out = self .op_def ._apply_func (** apply_arguments )
588
592
op_tf = tf_out .op
589
593
594
+ # TODO: Update NodeDef attrs?
595
+
590
596
assert op_tf is not None
591
597
self ._obj = op_tf
592
598
return self .obj
@@ -623,14 +629,8 @@ def name(self):
623
629
624
630
if self .obj is not None and not isinstance (self .obj , Var ):
625
631
name = self .obj .name
626
- elif (
627
- self .op is not None
628
- and not isvar (self .op )
629
- and not isvar (self .op .name )
630
- and not isinstance (self .op .outputs , Sequence )
631
- ):
632
- out_num = self .op .outputs .index (self )
633
- name = f"{ self .op .name } :{ out_num } "
632
+ elif isinstance (getattr (self .op , "name" , None ), str ) and not isvar (self .value_index ):
633
+ name = f"{ self .op .name } :{ self .value_index } "
634
634
else :
635
635
name = var ()
636
636
0 commit comments