28
28
MathError )
29
29
from brainpy .nn .algorithms .offline import OfflineAlgorithm
30
30
from brainpy .nn .algorithms .online import OnlineAlgorithm
31
- from brainpy .nn .constants import (PASS_SEQUENCE ,
32
- DATA_PASS_FUNC ,
33
- DATA_PASS_TYPES )
31
+ from brainpy .nn .datatypes import (DataType , SingleData , MultipleData )
34
32
from brainpy .nn .graph_flow import (find_senders_and_receivers ,
35
33
find_entries_and_exits ,
36
34
detect_cycle ,
@@ -83,13 +81,13 @@ def feedback(self):
83
81
class Node (Base ):
84
82
"""Basic Node class for neural network building in BrainPy."""
85
83
86
- '''Support multiple types of data pass, including "PASS_SEQUENCE " (by default),
87
- "PASS_NAME_DICT ", "PASS_NODE_DICT" and user-customized type which registered
88
- by ``brainpy.nn.register_data_pass_type()`` function .
84
+ '''Support multiple types of data pass, including "PassOnlyOne " (by default),
85
+ "PassSequence ", "PassNameDict", etc. and user-customized type which inherits
86
+ from basic "SingleData" or "MultipleData" .
89
87
90
88
This setting will change the feedforward/feedback input data which pass into
91
89
the "call()" function and the sizes of the feedforward/feedback input data.'''
92
- data_pass_type = PASS_SEQUENCE
90
+ data_pass = SingleData ()
93
91
94
92
'''Offline fitting method.'''
95
93
offline_fit_by : Union [Callable , OfflineAlgorithm ]
@@ -115,11 +113,10 @@ def __init__(
115
113
self ._trainable = trainable
116
114
self ._state = None # the state of the current node
117
115
self ._fb_output = None # the feedback output of the current node
118
- # data pass function
119
- if self .data_pass_type not in DATA_PASS_FUNC :
120
- raise ValueError (f'Unsupported data pass type { self .data_pass_type } . '
121
- f'Only support { DATA_PASS_TYPES } ' )
122
- self .data_pass_func = DATA_PASS_FUNC [self .data_pass_type ]
116
+ # data pass
117
+ if not isinstance (self .data_pass , DataType ):
118
+ raise ValueError (f'Unsupported data pass type { type (self .data_pass )} . '
119
+ f'Only support { DataType .__class__ } ' )
123
120
124
121
# super initialization
125
122
super (Node , self ).__init__ (name = name )
@@ -129,11 +126,10 @@ def __init__(
129
126
self ._feedforward_shapes = {self .name : (None ,) + tools .to_size (input_shape )}
130
127
131
128
def __repr__ (self ):
132
- name = type (self ).__name__
133
- prefix = ' ' * (len (name ) + 1 )
134
- line1 = f"{ name } (name={ self .name } , forwards={ self .feedforward_shapes } , \n "
135
- line2 = f"{ prefix } feedbacks={ self .feedback_shapes } , output={ self .output_shape } )"
136
- return line1 + line2
129
+ return (f"{ type (self ).__name__ } (name={ self .name } , "
130
+ f"forwards={ self .feedforward_shapes } , "
131
+ f"feedbacks={ self .feedback_shapes } , "
132
+ f"output={ self .output_shape } )" )
137
133
138
134
def __call__ (self , * args , ** kwargs ) -> Tensor :
139
135
"""The main computation function of a Node.
@@ -298,7 +294,7 @@ def trainable(self, value: bool):
298
294
@property
299
295
def feedforward_shapes (self ):
300
296
"""Input data size."""
301
- return self .data_pass_func (self ._feedforward_shapes )
297
+ return self .data_pass . filter (self ._feedforward_shapes )
302
298
303
299
@feedforward_shapes .setter
304
300
def feedforward_shapes (self , size ):
@@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
324
320
@property
325
321
def feedback_shapes (self ):
326
322
"""Output data size."""
327
- return self .data_pass_func (self ._feedback_shapes )
323
+ return self .data_pass . filter (self ._feedback_shapes )
328
324
329
325
@feedback_shapes .setter
330
326
def feedback_shapes (self , size ):
@@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None):
530
526
f'batch size by ".initialize(num_batch)", or change the data '
531
527
f'consistent with the data batch size { self .state .shape [0 ]} .' )
532
528
# data
533
- ff = self .data_pass_func (ff )
534
- fb = self .data_pass_func (fb )
529
+ ff = self .data_pass . filter (ff )
530
+ fb = self .data_pass . filter (fb )
535
531
return ff , fb
536
532
537
533
def _call (self ,
@@ -747,6 +743,8 @@ def set_state(self, state):
747
743
class Network (Node ):
748
744
"""Basic Network class for neural network building in BrainPy."""
749
745
746
+ data_pass = MultipleData ('sequence' )
747
+
750
748
def __init__ (self ,
751
749
nodes : Optional [Sequence [Node ]] = None ,
752
750
ff_edges : Optional [Sequence [Tuple [Node ]]] = None ,
@@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None):
1145
1143
check_shape_except_batch (size , fb [k ].shape )
1146
1144
1147
1145
# data transformation
1148
- ff = self .data_pass_func (ff )
1149
- fb = self .data_pass_func (fb )
1146
+ ff = self .data_pass . filter (ff )
1147
+ fb = self .data_pass . filter (fb )
1150
1148
return ff , fb
1151
1149
1152
1150
def _call (self ,
@@ -1208,12 +1206,12 @@ def _call(self,
1208
1206
def _call_a_node (self , node , ff , fb , monitors , forced_states ,
1209
1207
parent_outputs , children_queue , ff_senders ,
1210
1208
** shared_kwargs ):
1211
- ff = node .data_pass_func (ff )
1209
+ ff = node .data_pass . filter (ff )
1212
1210
if f'{ node .name } .inputs' in monitors :
1213
1211
monitors [f'{ node .name } .inputs' ] = ff
1214
1212
# get the output results
1215
1213
if len (fb ):
1216
- fb = node .data_pass_func (fb )
1214
+ fb = node .data_pass . filter (fb )
1217
1215
if f'{ node .name } .feedbacks' in monitors :
1218
1216
monitors [f'{ node .name } .feedbacks' ] = fb
1219
1217
parent_outputs [node ] = node .forward (ff , fb , ** shared_kwargs )
@@ -1440,7 +1438,7 @@ def plot_node_graph(self,
1440
1438
if len (nodes_untrainable ):
1441
1439
proxie .append (Line2D ([], [], color = 'white' , marker = 'o' ,
1442
1440
markerfacecolor = untrainable_color ))
1443
- labels .append ('Untrainable ' )
1441
+ labels .append ('Nontrainable ' )
1444
1442
if len (ff_edges ):
1445
1443
proxie .append (Line2D ([], [], color = ff_color , linewidth = 2 ))
1446
1444
labels .append ('Feedforward' )
0 commit comments