Skip to content

Commit e35b09d

Browse files
authored
Merge branch 'master' into updates
2 parents 822d56f + 47b7539 commit e35b09d

File tree

7 files changed

+78
-56
lines changed

7 files changed

+78
-56
lines changed

brainpy/nn/nodes/ANN/batch_norm.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
# -*- coding: utf-8 -*-
22

3-
"""
4-
adapted from jax.example_libraries.stax.BatchNorm
5-
https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
6-
"""
7-
8-
93
from typing import Union
104

115
import jax.nn
@@ -29,14 +23,23 @@ class BatchNorm(Node):
2923
Most commonly, the first axis of the data is the batch, and the last is
3024
the channel. However, users can specify the axes to be normalized.
3125
26+
adapted from jax.example_libraries.stax.BatchNorm
27+
https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
28+
3229
Parameters
3330
----------
34-
axis: axes where the data will be normalized. The axis of channels should be excluded.
35-
epsilon: a value added to the denominator for numerical stability. Default: 1e-5
36-
translate: whether to translate data in refactoring
37-
scale: whether to scale data in refactoring
38-
beta_init: an initializer generating the original translation matrix
39-
gamma_init: an initializer generating the original scaling matrix
31+
axis: int, tuple, list
32+
axes where the data will be normalized. The axis of channels should be excluded.
33+
epsilon: float
34+
a value added to the denominator for numerical stability. Default: 1e-5
35+
translate: bool
36+
whether to translate data in refactoring
37+
scale: bool
38+
whether to scale data in refactoring
39+
beta_init: brainpy.init.Initializer
40+
an initializer generating the original translation matrix
41+
gamma_init: brainpy.init.Initializer
42+
an initializer generating the original scaling matrix
4043
"""
4144
def __init__(self,
4245
axis: Union[int, tuple, list],
@@ -86,10 +89,14 @@ class BatchNorm1d(BatchNorm):
8689
axes where the data will be normalized. The axis of channels should be excluded.
8790
epsilon: float
8891
a value added to the denominator for numerical stability. Default: 1e-5
89-
translate: whether to translate data in refactoring
90-
scale: whether to scale data in refactoring
91-
beta_init: an initializer generating the original translation matrix
92-
gamma_init: an initializer generating the original scaling matrix
92+
translate: bool
93+
whether to translate data in refactoring
94+
scale: bool
95+
whether to scale data in refactoring
96+
beta_init: brainpy.init.Initializer
97+
an initializer generating the original translation matrix
98+
gamma_init: brainpy.init.Initializer
99+
an initializer generating the original scaling matrix
93100
"""
94101
def __init__(self, axis=(0, 1), **kwargs):
95102
super(BatchNorm1d, self).__init__(axis=axis, **kwargs)
@@ -138,20 +145,24 @@ def _check_input_dim(self):
138145

139146
class BatchNorm3d(BatchNorm):
140147
"""3-D batch normalization.
141-
The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
142-
`h` is the height dimension, `w` is the width dimension, `d` is the depth
143-
dimension, and `c` is the channel dimension.
144-
145-
Parameters
146-
----------
147-
axis: int, tuple, list
148-
axes where the data will be normalized. The axis of channels should be excluded.
149-
epsilon: float
150-
a value added to the denominator for numerical stability. Default: 1e-5
151-
translate: whether to translate data in refactoring
152-
scale: whether to scale data in refactoring
153-
beta_init: an initializer generating the original translation matrix
154-
gamma_init: an initializer generating the original scaling matrix
148+
The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
149+
`h` is the height dimension, `w` is the width dimension, `d` is the depth
150+
dimension, and `c` is the channel dimension.
151+
152+
Parameters
153+
----------
154+
axis: int, tuple, list
155+
axes where the data will be normalized. The axis of channels should be excluded.
156+
epsilon: float
157+
a value added to the denominator for numerical stability. Default: 1e-5
158+
translate: bool
159+
whether to translate data in refactoring
160+
scale: bool
161+
whether to scale data in refactoring
162+
beta_init: brainpy.init.Initializer
163+
an initializer generating the original translation matrix
164+
gamma_init: brainpy.init.Initializer
165+
an initializer generating the original scaling matrix
155166
"""
156167
def __init__(self, axis=(0, 1, 2, 3), **kwargs):
157168
super(BatchNorm3d, self).__init__(axis=axis, **kwargs)

brainpy/nn/nodes/ANN/tests/test_batchnorm.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,26 +76,26 @@ def test_batchnorm1(self):
7676

7777
print(model(inputs))
7878

79-
# def test_batchnorm2(self):
80-
# i = bp.nn.Input((3, 4))
81-
# b = bp.nn.BatchNorm(axis=(0, 2)) # channel axis: 1
82-
# f = bp.nn.Reshape((-1, 12))
83-
# o = bp.nn.GeneralDense(2)
84-
# model = i >> b >>f >> o
85-
# model.initialize(num_batch=2)
86-
#
87-
# inputs = bp.math.ones((2, 3, 4))
88-
# inputs[0, 0, :] = 2.
89-
# inputs[0, 1, 0] = 5.
90-
# # print(inputs)
91-
# print(model(inputs))
92-
#
93-
#
94-
# X = bp.math.random.random((1000, 10, 3, 4))
95-
# Y = bp.math.random.randint(0, 2, (1000, 10, 2))
96-
# trainer = bp.nn.BPTT(model,
97-
# loss=bp.losses.cross_entropy_loss,
98-
# optimizer=bp.optim.Adam(lr=1e-3))
99-
# trainer.fit([X, Y])
79+
def test_batchnorm2(self):
80+
i = bp.nn.Input((3, 4))
81+
b = bp.nn.BatchNorm(axis=(0, 2)) # channel axis: 1
82+
f = bp.nn.Reshape((-1, 12))
83+
o = bp.nn.GeneralDense(2)
84+
model = i >> b >>f >> o
85+
model.initialize(num_batch=2)
86+
87+
inputs = bp.math.ones((2, 3, 4))
88+
inputs[0, 0, :] = 2.
89+
inputs[0, 1, 0] = 5.
90+
# print(inputs)
91+
print(model(inputs))
92+
93+
94+
X = bp.math.random.random((1000, 10, 3, 4))
95+
Y = bp.math.random.randint(0, 2, (1000, 10, 2))
96+
trainer = bp.nn.BPTT(model,
97+
loss=bp.losses.cross_entropy_loss,
98+
optimizer=bp.optim.Adam(lr=1e-3))
99+
trainer.fit([X, Y])
100100

101101

extensions/brainpylib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "0.0.4"
3+
__version__ = "0.0.5"
44

55
# IMPORTANT, must import first
66
from . import register_custom_calls

extensions/changelog.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
Release notes (brainpylib)
22
##########################
33

4+
Version 0.0.5
5+
=============
6+
7+
- Support operator customization on GPU by ``numba``
8+
9+
10+
Version 0.0.4
11+
=============
12+
13+
- Support operator customization on CPU by ``numba``
14+
415

516
Version 0.0.3
617
=============

extensions/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
author_email='[email protected]',
3535
packages=find_packages(exclude=['lib*']),
3636
include_package_data=True,
37-
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8"],
37+
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"],
3838
extras_require={"test": "pytest"},
3939
python_requires='>=3.7',
4040
url='https://github.com/PKU-NIP-Lab/BrainPy',

extensions/setup_cuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def build_extension(self, ext):
9090
author_email='[email protected]',
9191
packages=find_packages(exclude=['lib*']),
9292
include_package_data=True,
93-
install_requires=["jax", "jaxlib"],
93+
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"],
9494
extras_require={"test": "pytest"},
9595
python_requires='>=3.7',
9696
url='https://github.com/PKU-NIP-Lab/BrainPy',

extensions/setup_mac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
sources=["lib/cpu_ops.cc"] + glob.glob("lib/*_cpu.cc"),
2323
cxx_std=11,
2424
# extra_link_args=["-rpath", "/Users/ztqakita/miniforge3/lib"], # m1
25-
extra_link_args=["-rpath", "/Users/ztqakita/miniforge3/lib"], # intel
25+
extra_link_args=["-rpath", "/Users/ztqakita/opt/miniconda3/lib"], # intel
2626
define_macros=[('VERSION_INFO', __version__)]),
2727
]
2828

@@ -36,7 +36,7 @@
3636
author_email='[email protected]',
3737
packages=find_packages(exclude=['lib*']),
3838
include_package_data=True,
39-
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8"],
39+
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"],
4040
extras_require={"test": "pytest"},
4141
python_requires='>=3.7',
4242
url='https://github.com/PKU-NIP-Lab/BrainPy',

0 commit comments

Comments
 (0)