Skip to content

Commit 89f9b65

Browse files
authored
data pass of the Node is default SingleData (#148)
2 parents 47b7539 + e35b09d commit 89f9b65

18 files changed

+430
-179
lines changed

brainpy/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""Neural Networks (nn)"""
44

55
from .base import *
6-
from .constants import *
6+
from .datatypes import *
77
from .graph_flow import *
88
from .nodes import *
99
from .graph_flow import *

brainpy/nn/base.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
MathError)
2929
from brainpy.nn.algorithms.offline import OfflineAlgorithm
3030
from brainpy.nn.algorithms.online import OnlineAlgorithm
31-
from brainpy.nn.constants import (PASS_SEQUENCE,
32-
DATA_PASS_FUNC,
33-
DATA_PASS_TYPES)
31+
from brainpy.nn.datatypes import (DataType, SingleData, MultipleData)
3432
from brainpy.nn.graph_flow import (find_senders_and_receivers,
3533
find_entries_and_exits,
3634
detect_cycle,
@@ -83,13 +81,13 @@ def feedback(self):
8381
class Node(Base):
8482
"""Basic Node class for neural network building in BrainPy."""
8583

86-
'''Support multiple types of data pass, including "PASS_SEQUENCE" (by default),
87-
"PASS_NAME_DICT", "PASS_NODE_DICT" and user-customized type which registered
88-
by ``brainpy.nn.register_data_pass_type()`` function.
84+
'''Support multiple types of data pass, including "PassOnlyOne" (by default),
85+
"PassSequence", "PassNameDict", etc. and user-customized type which inherits
86+
from basic "SingleData" or "MultipleData".
8987
9088
This setting will change the feedforward/feedback input data which pass into
9189
the "call()" function and the sizes of the feedforward/feedback input data.'''
92-
data_pass_type = PASS_SEQUENCE
90+
data_pass = SingleData()
9391

9492
'''Offline fitting method.'''
9593
offline_fit_by: Union[Callable, OfflineAlgorithm]
@@ -115,11 +113,10 @@ def __init__(
115113
self._trainable = trainable
116114
self._state = None # the state of the current node
117115
self._fb_output = None # the feedback output of the current node
118-
# data pass function
119-
if self.data_pass_type not in DATA_PASS_FUNC:
120-
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
121-
f'Only support {DATA_PASS_TYPES}')
122-
self.data_pass_func = DATA_PASS_FUNC[self.data_pass_type]
116+
# data pass
117+
if not isinstance(self.data_pass, DataType):
118+
raise ValueError(f'Unsupported data pass type {type(self.data_pass)}. '
119+
f'Only support {DataType.__class__}')
123120

124121
# super initialization
125122
super(Node, self).__init__(name=name)
@@ -129,11 +126,10 @@ def __init__(
129126
self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)}
130127

131128
def __repr__(self):
132-
name = type(self).__name__
133-
prefix = ' ' * (len(name) + 1)
134-
line1 = f"{name}(name={self.name}, forwards={self.feedforward_shapes}, \n"
135-
line2 = f"{prefix}feedbacks={self.feedback_shapes}, output={self.output_shape})"
136-
return line1 + line2
129+
return (f"{type(self).__name__}(name={self.name}, "
130+
f"forwards={self.feedforward_shapes}, "
131+
f"feedbacks={self.feedback_shapes}, "
132+
f"output={self.output_shape})")
137133

138134
def __call__(self, *args, **kwargs) -> Tensor:
139135
"""The main computation function of a Node.
@@ -298,7 +294,7 @@ def trainable(self, value: bool):
298294
@property
299295
def feedforward_shapes(self):
300296
"""Input data size."""
301-
return self.data_pass_func(self._feedforward_shapes)
297+
return self.data_pass.filter(self._feedforward_shapes)
302298

303299
@feedforward_shapes.setter
304300
def feedforward_shapes(self, size):
@@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
324320
@property
325321
def feedback_shapes(self):
326322
"""Output data size."""
327-
return self.data_pass_func(self._feedback_shapes)
323+
return self.data_pass.filter(self._feedback_shapes)
328324

329325
@feedback_shapes.setter
330326
def feedback_shapes(self, size):
@@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None):
530526
f'batch size by ".initialize(num_batch)", or change the data '
531527
f'consistent with the data batch size {self.state.shape[0]}.')
532528
# data
533-
ff = self.data_pass_func(ff)
534-
fb = self.data_pass_func(fb)
529+
ff = self.data_pass.filter(ff)
530+
fb = self.data_pass.filter(fb)
535531
return ff, fb
536532

537533
def _call(self,
@@ -747,6 +743,8 @@ def set_state(self, state):
747743
class Network(Node):
748744
"""Basic Network class for neural network building in BrainPy."""
749745

746+
data_pass = MultipleData('sequence')
747+
750748
def __init__(self,
751749
nodes: Optional[Sequence[Node]] = None,
752750
ff_edges: Optional[Sequence[Tuple[Node]]] = None,
@@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None):
11451143
check_shape_except_batch(size, fb[k].shape)
11461144

11471145
# data transformation
1148-
ff = self.data_pass_func(ff)
1149-
fb = self.data_pass_func(fb)
1146+
ff = self.data_pass.filter(ff)
1147+
fb = self.data_pass.filter(fb)
11501148
return ff, fb
11511149

11521150
def _call(self,
@@ -1208,12 +1206,12 @@ def _call(self,
12081206
def _call_a_node(self, node, ff, fb, monitors, forced_states,
12091207
parent_outputs, children_queue, ff_senders,
12101208
**shared_kwargs):
1211-
ff = node.data_pass_func(ff)
1209+
ff = node.data_pass.filter(ff)
12121210
if f'{node.name}.inputs' in monitors:
12131211
monitors[f'{node.name}.inputs'] = ff
12141212
# get the output results
12151213
if len(fb):
1216-
fb = node.data_pass_func(fb)
1214+
fb = node.data_pass.filter(fb)
12171215
if f'{node.name}.feedbacks' in monitors:
12181216
monitors[f'{node.name}.feedbacks'] = fb
12191217
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs)
@@ -1440,7 +1438,7 @@ def plot_node_graph(self,
14401438
if len(nodes_untrainable):
14411439
proxie.append(Line2D([], [], color='white', marker='o',
14421440
markerfacecolor=untrainable_color))
1443-
labels.append('Untrainable')
1441+
labels.append('Nontrainable')
14441442
if len(ff_edges):
14451443
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
14461444
labels.append('Feedforward')

brainpy/nn/constants.py

Lines changed: 0 additions & 114 deletions
This file was deleted.

brainpy/nn/datatypes.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
__all__ = [
5+
# data types
6+
'DataType',
7+
8+
# pass rules
9+
'SingleData',
10+
'MultipleData',
11+
]
12+
13+
14+
class DataType(object):
15+
"""Base class for data type."""
16+
17+
def filter(self, data):
18+
raise NotImplementedError
19+
20+
def __repr__(self):
21+
return self.__class__.__name__
22+
23+
24+
class SingleData(DataType):
25+
"""Pass the only one data into the node.
26+
If there are multiple data, an error will be raised. """
27+
28+
def filter(self, data):
29+
if data is None:
30+
return None
31+
if len(data) > 1:
32+
raise ValueError(f'{self.__class__.__name__} only support one '
33+
f'feedforward/feedback input. But we got {len(data)}.')
34+
return tuple(data.values())[0]
35+
36+
def __repr__(self):
37+
return self.__class__.__name__
38+
39+
40+
class MultipleData(DataType):
41+
"""Pass a list/tuple of data into the node."""
42+
43+
def __init__(self, return_type: str = 'sequence'):
44+
if return_type not in ['sequence', 'name_dict', 'type_dict', 'node_dict']:
45+
raise ValueError(f"Only support return type of 'sequence', 'name_dict', "
46+
f"'type_dict' and 'node_dict'. But we got {return_type}")
47+
self.return_type = return_type
48+
49+
from brainpy.nn.base import Node
50+
51+
if return_type == 'sequence':
52+
f = lambda data: tuple(data.values())
53+
54+
elif return_type == 'name_dict':
55+
# Pass a dict with <node name, data> into the node.
56+
57+
def f(data):
58+
_res = dict()
59+
for node, val in data.items():
60+
if isinstance(node, str):
61+
_res[node] = val
62+
elif isinstance(node, Node):
63+
_res[node.name] = val
64+
else:
65+
raise ValueError(f'Unknown type {type(node)}: node')
66+
return _res
67+
68+
elif return_type == 'type_dict':
69+
# Pass a dict with <node type, data> into the node.
70+
71+
def f(data):
72+
_res = dict()
73+
for node, val in data.items():
74+
if isinstance(node, str):
75+
_res[str] = val
76+
elif isinstance(node, Node):
77+
_res[type(node.name)] = val
78+
else:
79+
raise ValueError(f'Unknown type {type(node)}: node')
80+
return _res
81+
82+
elif return_type == 'node_dict':
83+
# Pass a dict with <node, data> into the node.
84+
f = lambda data: data
85+
86+
else:
87+
raise ValueError
88+
self.return_func = f
89+
90+
def __repr__(self):
91+
return f'{self.__class__.__name__}(return_type={self.return_type})'
92+
93+
def filter(self, data):
94+
if data is None:
95+
return None
96+
else:
97+
return self.return_func(data)

brainpy/nn/nodes/ANN/batch_norm.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Sequence, Optional, Dict, Callable, Union
3+
from typing import Union
44

55
import jax.nn
66
import jax.numpy as jnp
77

8-
import brainpy.math as bm
98
import brainpy
9+
import brainpy.math as bm
1010
from brainpy.initialize import ZeroInit, OneInit, Initializer
1111
from brainpy.nn.base import Node
12-
from brainpy.nn.constants import PASS_ONLY_ONE
13-
1412

1513
__all__ = [
1614
'BatchNorm',
@@ -43,8 +41,6 @@ class BatchNorm(Node):
4341
gamma_init: brainpy.init.Initializer
4442
an initializer generating the original scaling matrix
4543
"""
46-
data_pass_type = PASS_ONLY_ONE
47-
4844
def __init__(self,
4945
axis: Union[int, tuple, list],
5046
epsilon: float = 1e-5,

0 commit comments

Comments
 (0)