1
1
# -*- coding: utf-8 -*-
2
2
3
- """
4
- adapted from jax.example_libraries.stax.BatchNorm
5
- https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
6
- """
7
-
8
-
9
3
from typing import Union
10
4
11
5
import jax .nn
@@ -29,14 +23,23 @@ class BatchNorm(Node):
29
23
Most commonly, the first axis of the data is the batch, and the last is
30
24
the channel. However, users can specify the axes to be normalized.
31
25
26
+ adapted from jax.example_libraries.stax.BatchNorm
27
+ https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
28
+
32
29
Parameters
33
30
----------
34
- axis: axes where the data will be normalized. The axis of channels should be excluded.
35
- epsilon: a value added to the denominator for numerical stability. Default: 1e-5
36
- translate: whether to translate data in refactoring
37
- scale: whether to scale data in refactoring
38
- beta_init: an initializer generating the original translation matrix
39
- gamma_init: an initializer generating the original scaling matrix
31
+ axis: int, tuple, list
32
+ axes where the data will be normalized. The axis of channels should be excluded.
33
+ epsilon: float
34
+ a value added to the denominator for numerical stability. Default: 1e-5
35
+ translate: bool
36
+ whether to translate data in refactoring
37
+ scale: bool
38
+ whether to scale data in refactoring
39
+ beta_init: brainpy.init.Initializer
40
+ an initializer generating the original translation matrix
41
+ gamma_init: brainpy.init.Initializer
42
+ an initializer generating the original scaling matrix
40
43
"""
41
44
def __init__ (self ,
42
45
axis : Union [int , tuple , list ],
@@ -86,10 +89,14 @@ class BatchNorm1d(BatchNorm):
86
89
axes where the data will be normalized. The axis of channels should be excluded.
87
90
epsilon: float
88
91
a value added to the denominator for numerical stability. Default: 1e-5
89
- translate: whether to translate data in refactoring
90
- scale: whether to scale data in refactoring
91
- beta_init: an initializer generating the original translation matrix
92
- gamma_init: an initializer generating the original scaling matrix
92
+ translate: bool
93
+ whether to translate data in refactoring
94
+ scale: bool
95
+ whether to scale data in refactoring
96
+ beta_init: brainpy.init.Initializer
97
+ an initializer generating the original translation matrix
98
+ gamma_init: brainpy.init.Initializer
99
+ an initializer generating the original scaling matrix
93
100
"""
94
101
def __init__ (self , axis = (0 , 1 ), ** kwargs ):
95
102
super (BatchNorm1d , self ).__init__ (axis = axis , ** kwargs )
@@ -138,20 +145,24 @@ def _check_input_dim(self):
138
145
139
146
class BatchNorm3d (BatchNorm ):
140
147
"""3-D batch normalization.
141
- The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
142
- `h` is the height dimension, `w` is the width dimension, `d` is the depth
143
- dimension, and `c` is the channel dimension.
144
-
145
- Parameters
146
- ----------
147
- axis: int, tuple, list
148
- axes where the data will be normalized. The axis of channels should be excluded.
149
- epsilon: float
150
- a value added to the denominator for numerical stability. Default: 1e-5
151
- translate: whether to translate data in refactoring
152
- scale: whether to scale data in refactoring
153
- beta_init: an initializer generating the original translation matrix
154
- gamma_init: an initializer generating the original scaling matrix
148
+ The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
149
+ `h` is the height dimension, `w` is the width dimension, `d` is the depth
150
+ dimension, and `c` is the channel dimension.
151
+
152
+ Parameters
153
+ ----------
154
+ axis: int, tuple, list
155
+ axes where the data will be normalized. The axis of channels should be excluded.
156
+ epsilon: float
157
+ a value added to the denominator for numerical stability. Default: 1e-5
158
+ translate: bool
159
+ whether to translate data in refactoring
160
+ scale: bool
161
+ whether to scale data in refactoring
162
+ beta_init: brainpy.init.Initializer
163
+ an initializer generating the original translation matrix
164
+ gamma_init: brainpy.init.Initializer
165
+ an initializer generating the original scaling matrix
155
166
"""
156
167
def __init__ (self , axis = (0 , 1 , 2 , 3 ), ** kwargs ):
157
168
super (BatchNorm3d , self ).__init__ (axis = axis , ** kwargs )
0 commit comments