Skip to content

Commit f84cc1e

Browse files
revert functional.py
1 parent 772929c commit f84cc1e

File tree

1 file changed

+6
-20
lines changed

1 file changed

+6
-20
lines changed

keras/src/models/functional.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -139,28 +139,13 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
139139
self.trainable = trainable
140140

141141
self._layers = self.layers
142-
143142
self.build(None)
144143
# We will convert directly (to the correct dtype per input).
145144
self._convert_input_args = False
146145
self._allow_non_tensor_positional_args = True
147146
output_layers = [x._keras_history[0] for x in self.outputs]
148147
self.output_names = [x.name for x in output_layers]
149148

150-
def _setup_nnx_op_mapping(self):
151-
"""Setup operation mapping for NNX"""
152-
# Create a mapping from operation id to operation instance
153-
self._nnx_op_mapping = {}
154-
155-
# Store operations as direct attributes for NNX traversal
156-
for i, operation in enumerate(self._operations):
157-
if isinstance(operation, Layer):
158-
# Store operation as direct attribute with unique name
159-
attr_name = f"_layer_{i}_{operation.name}"
160-
setattr(self, attr_name, operation)
161-
# Map the operation id to this operation instance
162-
self._nnx_op_mapping[id(operation)] = operation
163-
164149
def _lock_state(self):
165150
# Unlike other layers, we allow Functional state to be mutable after
166151
# build. E.g. to attach a layer to a model that is not part of the
@@ -186,6 +171,7 @@ def layers(self, _):
186171
)
187172

188173
def call(self, inputs, training=None, mask=None, **kwargs):
174+
# Add support for training, masking
189175
inputs = self._standardize_inputs(inputs)
190176
if mask is None:
191177
masks = [None] * len(inputs)
@@ -407,7 +393,7 @@ def get_config(self):
407393
# the author of the subclassed network).
408394
return Model.get_config(self)
409395

410-
cfg = {
396+
config = {
411397
"name": self.name,
412398
"trainable": self.trainable,
413399
}
@@ -454,7 +440,7 @@ def get_config(self):
454440
layer_config["name"] = operation.name
455441
layer_config["inbound_nodes"] = filtered_inbound_nodes
456442
layer_configs.append(layer_config)
457-
cfg["layers"] = layer_configs
443+
config["layers"] = layer_configs
458444

459445
# Gather info about inputs and outputs.
460446
def get_tensor_config(tensor):
@@ -469,9 +455,9 @@ def get_tensor_config(tensor):
469455
def map_tensors(tensors):
470456
return tree.map_structure(get_tensor_config, tensors)
471457

472-
cfg["input_layers"] = map_tensors(self._inputs_struct)
473-
cfg["output_layers"] = map_tensors(self._outputs_struct)
474-
return copy.deepcopy(cfg)
458+
config["input_layers"] = map_tensors(self._inputs_struct)
459+
config["output_layers"] = map_tensors(self._outputs_struct)
460+
return copy.deepcopy(config)
475461

476462

477463
def functional_from_config(cls, config, custom_objects=None):

0 commit comments

Comments
 (0)