@@ -139,28 +139,13 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
         self.trainable = trainable
 
         self._layers = self.layers
-
         self.build(None)
         # We will convert directly (to the correct dtype per input).
         self._convert_input_args = False
         self._allow_non_tensor_positional_args = True
         output_layers = [x._keras_history[0] for x in self.outputs]
         self.output_names = [x.name for x in output_layers]
 
-    def _setup_nnx_op_mapping(self):
-        """Setup operation mapping for NNX"""
-        # Create a mapping from operation id to operation instance
-        self._nnx_op_mapping = {}
-
-        # Store operations as direct attributes for NNX traversal
-        for i, operation in enumerate(self._operations):
-            if isinstance(operation, Layer):
-                # Store operation as direct attribute with unique name
-                attr_name = f"_layer_{i}_{operation.name}"
-                setattr(self, attr_name, operation)
-                # Map the operation id to this operation instance
-                self._nnx_op_mapping[id(operation)] = operation
-
     def _lock_state(self):
         # Unlike other layers, we allow Functional state to be mutable after
         # build. E.g. to attach a layer to a model that is not part of the
@@ -186,6 +171,7 @@ def layers(self, _):
         )
 
     def call(self, inputs, training=None, mask=None, **kwargs):
+        # Add support for training, masking
         inputs = self._standardize_inputs(inputs)
         if mask is None:
             masks = [None] * len(inputs)
@@ -407,7 +393,7 @@ def get_config(self):
             # the author of the subclassed network).
             return Model.get_config(self)
 
-        cfg = {
+        config = {
             "name": self.name,
             "trainable": self.trainable,
         }
@@ -454,7 +440,7 @@ def get_config(self):
             layer_config["name"] = operation.name
             layer_config["inbound_nodes"] = filtered_inbound_nodes
             layer_configs.append(layer_config)
-        cfg["layers"] = layer_configs
+        config["layers"] = layer_configs
 
         # Gather info about inputs and outputs.
         def get_tensor_config(tensor):
@@ -469,9 +455,9 @@ def get_tensor_config(tensor):
         def map_tensors(tensors):
             return tree.map_structure(get_tensor_config, tensors)
 
-        cfg["input_layers"] = map_tensors(self._inputs_struct)
-        cfg["output_layers"] = map_tensors(self._outputs_struct)
-        return copy.deepcopy(cfg)
+        config["input_layers"] = map_tensors(self._inputs_struct)
+        config["output_layers"] = map_tensors(self._outputs_struct)
+        return copy.deepcopy(config)
 
 
 def functional_from_config(cls, config, custom_objects=None):
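
For context on the get_config() changes above: the dict assembled there (keys "name", "trainable", "layers", "input_layers", "output_layers") is what the public serialization round-trip consumes, and functional_from_config() rebuilds the model from it. A minimal sketch using only the public Keras 3 API; the model and layer names here are illustrative, not taken from this diff.

import keras

# Small functional model; Functional.get_config() (edited above) backs
# model.get_config() for this kind of model.
inputs = keras.Input(shape=(4,), name="feats")
outputs = keras.layers.Dense(2, name="head")(inputs)
model = keras.Model(inputs, outputs, name="demo")

config = model.get_config()
# config contains "name", "trainable", "layers", "input_layers",
# and "output_layers", matching the keys populated in get_config().
clone = keras.Model.from_config(config)  # rebuilds an equivalent graph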