Skip to content

Commit ff2628c

Browse files
author
Alexander Ororbia
committed
refactored/revised hebb-deconv syn w/ unit-test
1 parent 5c56389 commit ff2628c

File tree

5 files changed

+159
-58
lines changed

5 files changed

+159
-58
lines changed

ngclearn/components/synapses/convolution/convSynapse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ class ConvSynapse(JaxComponent): ## base-level convolutional cable
5050
# Define Functions
5151
def __init__(
5252
self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1.,
53-
batch_size=1, **kwargs):
53+
batch_size=1, **kwargs
54+
):
5455
super().__init__(name, **kwargs)
5556

5657
self.filter_init = filter_init

ngclearn/components/synapses/convolution/deconvSynapse.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
32
from ngclearn.components.jaxComponent import JaxComponent
4-
import ngclearn.utils.weight_distribution as dist
5-
from ngclearn.components.synapses.convolution.ngcconv import deconv2d
3+
from ngcsimlib.compilers.process import transition
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
6+
7+
from ngclearn.utils.weight_distribution import initialize_params
68
from ngcsimlib.logger import info
79
from ngclearn.utils import tensorstats
10+
import ngclearn.utils.weight_distribution as dist
11+
from ngclearn.components.synapses.convolution.ngcconv import deconv2d
812

913
class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable
1014
"""
@@ -44,8 +48,10 @@ class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable
4448
"""
4549

4650
# Define Functions
47-
def __init__(self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1,
48-
padding=None, resist_scale=1., batch_size=1, **kwargs):
51+
def __init__(
52+
self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1.,
53+
batch_size=1, **kwargs
54+
):
4955
super().__init__(name, **kwargs)
5056

5157
self.filter_init = filter_init
@@ -83,29 +89,22 @@ def __init__(self, name, shape, x_shape, filter_init=None, bias_init=None, strid
8389
(1, shape[1]))
8490
if bias_init else 0.0)
8591

92+
@transition(output_compartments=["outputs"])
8693
@staticmethod
87-
def _advance_state(Rscale, padding, stride, weights, biases, inputs):
94+
def advance_state(Rscale, padding, stride, weights, biases, inputs):
8895
_x = inputs
8996
out = deconv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases
9097
return out
9198

92-
@resolver(_advance_state)
93-
def advance_state(self, outputs):
94-
self.outputs.set(outputs)
95-
99+
@transition(output_compartments=["inputs", "outputs"])
96100
@staticmethod
97-
def _reset(in_shape, out_shape):
101+
def reset(in_shape, out_shape):
98102
preVals = jnp.zeros(in_shape)
99103
postVals = jnp.zeros(out_shape)
100104
inputs = preVals
101105
outputs = postVals
102106
return inputs, outputs
103107

104-
@resolver(_reset)
105-
def reset(self, inputs, outputs):
106-
self.inputs.set(inputs)
107-
self.outputs.set(outputs)
108-
109108
def save(self, directory, **kwargs):
110109
file_name = directory + "/" + self.name + ".npz"
111110
if self.bias_init != None:

ngclearn/components/synapses/convolution/hebbianConvSynapse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
9292
stride=1, padding=None, resist_scale=1., w_bound=0.,
9393
is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
9494
batch_size=1, **kwargs):
95-
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
96-
bias_init=bias_init, resist_scale=resist_scale, stride=stride,
97-
padding=padding, batch_size=batch_size, **kwargs)
95+
super().__init__(
96+
name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale,
97+
stride=stride, padding=padding, batch_size=batch_size, **kwargs
98+
)
9899

99100
self.eta = eta
100101
self.w_bounds = w_bound

ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
36
from .deconvSynapse import DeconvSynapse
7+
from ngclearn.utils.weight_distribution import initialize_params
8+
from ngcsimlib.logger import info
9+
from ngclearn.utils import tensorstats
10+
import ngclearn.utils.weight_distribution as dist
411
from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv,
512
_calc_dK_deconv, calc_dX_deconv,
613
calc_dK_deconv)
@@ -79,13 +86,15 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
7986
"""
8087

8188
# Define Functions
82-
def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
83-
stride=1, padding=None, resist_scale=1., w_bound=0., is_nonnegative=False,
84-
w_decay=0., sign_value=1., optim_type="sgd", batch_size=1, **kwargs):
85-
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
86-
bias_init=bias_init, resist_scale=resist_scale,
87-
stride=stride, padding=padding, batch_size=batch_size,
88-
**kwargs)
89+
def __init__(
90+
self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None,
91+
resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
92+
batch_size=1, **kwargs
93+
):
94+
super().__init__(
95+
name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale,
96+
stride=stride, padding=padding, batch_size=batch_size, **kwargs
97+
)
8998

9099
self.eta = eta
91100
self.w_bounds = w_bound
@@ -112,8 +121,7 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
112121
[self.weights.value, self.biases.value]
113122
if bias_init else [self.weights.value]))
114123

115-
def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
116-
weights):
124+
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
117125
k_size, k_size, n_in_chan, n_out_chan = shape
118126
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
119127
_d = deconv2d(_x, self.weights.value, stride_size=self.stride,
@@ -132,8 +140,7 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
132140
self.x_delta_shape = (dx, dy)
133141

134142
@staticmethod
135-
def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding,
136-
delta_shape, pre, post, weights):
143+
def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights):
137144
k_size, k_size, n_in_chan, n_out_chan = shape
138145
## compute adjustment to filters
139146
dWeights = calc_dK_deconv(pre, post, delta_shape=delta_shape,
@@ -148,10 +155,12 @@ def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding,
148155
dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
149156
return dWeights, dBiases
150157

158+
@transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
151159
@staticmethod
152-
def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
153-
shape, stride, padding, delta_shape, pre, post, weights, biases,
154-
opt_params):
160+
def evolve(
161+
opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, shape, stride, padding, delta_shape,
162+
pre, post, weights, biases, opt_params
163+
):
155164
dWeights, dBiases = HebbianDeconvSynapse._compute_update(
156165
sign_value, w_decay, bias_init, shape, stride, padding, delta_shape,
157166
pre, post, weights
@@ -169,30 +178,19 @@ def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
169178
weights = jnp.clip(weights, -w_bounds, w_bounds)
170179
return opt_params, weights, biases, dWeights, dBiases
171180

172-
@resolver(_evolve)
173-
def evolve(self, opt_params, weights, biases, dWeights, dBiases):
174-
self.opt_params.set(opt_params)
175-
self.weights.set(weights)
176-
self.biases.set(biases)
177-
self.dWeights.set(dWeights)
178-
self.dBiases.set(dBiases)
179-
181+
@transition(output_compartments=["dInputs"])
180182
@staticmethod
181-
def _backtransmit(sign_value, stride, padding, x_delta_shape, pre, post,
182-
weights): ## action-backpropagating routine
183+
def backtransmit(sign_value, stride, padding, x_delta_shape, pre, post, weights): ## action-backpropagating routine
183184
## calc dInputs
184185
dInputs = calc_dX_deconv(weights, post, delta_shape=x_delta_shape,
185186
stride_size=stride, padding=padding)
186187
## flip sign of back-transmitted signal (if applicable)
187188
dInputs = dInputs * sign_value
188189
return dInputs
189190

190-
@resolver(_backtransmit)
191-
def backtransmit(self, dInputs):
192-
self.dInputs.set(dInputs)
193-
191+
@transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
194192
@staticmethod
195-
def _reset(in_shape, out_shape):
193+
def reset(in_shape, out_shape):
196194
preVals = jnp.zeros(in_shape)
197195
postVals = jnp.zeros(out_shape)
198196
inputs = preVals
@@ -202,14 +200,6 @@ def _reset(in_shape, out_shape):
202200
dInputs = preVals
203201
return inputs, outputs, pre, post, dInputs
204202

205-
@resolver(_reset)
206-
def reset(self, inputs, outputs, pre, post, dInputs):
207-
self.inputs.set(inputs)
208-
self.outputs.set(outputs)
209-
self.pre.set(pre)
210-
self.post.set(post)
211-
self.dInputs.set(dInputs)
212-
213203
@classmethod
214204
def help(cls): ## component help function
215205
properties = {
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import HebbianDeconvSynapse
6+
import ngclearn.utils.weight_distribution as dist
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
15+
def test_HebbianDeconvSynapse1():
16+
name = "hebb_deconv_ctx"
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
padding_style = "SAME"
22+
stride = 1 #2
23+
batch_size = 1
24+
w_size = 2
25+
n_in_chan = 1 #4 # 1
26+
n_out_chan = 1 #5 # 1
27+
x_size = 2 #4
28+
29+
shape = (w_size, w_size, n_in_chan, n_out_chan)
30+
x_shape = (batch_size, x_size, x_size, n_in_chan)
31+
32+
# ---- build a simple Hebb-Conv system ----
33+
with Context(name) as ctx:
34+
a = HebbianDeconvSynapse(
35+
"a", shape, (x_size, x_size), eta=0., filter_init=dist.constant(value=1.), bias_init=None,
36+
stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0]
37+
)
38+
39+
#"""
40+
evolve_process = (Process()
41+
>> a.evolve)
42+
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
43+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
44+
45+
backtransmit_process = (Process()
46+
>> a.backtransmit)
47+
ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit")
48+
49+
advance_process = (Process()
50+
>> a.advance_state)
51+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
52+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
53+
54+
reset_process = (Process()
55+
>> a.reset)
56+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
57+
#"""
58+
59+
"""
60+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
61+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
62+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
63+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
64+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
65+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
66+
backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit")
67+
ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit")
68+
"""
69+
70+
x = jnp.ones(x_shape)
71+
72+
ctx.reset()
73+
a.inputs.set(x)
74+
ctx.run(t=1., dt=dt)
75+
y = a.outputs.value
76+
77+
y_truth = jnp.array(
78+
[[[[1.],[2.]],
79+
[[2.], [4.]]]]
80+
)
81+
82+
assert_array_equal(y, y_truth)
83+
#print(y)
84+
#print("======")
85+
86+
# print("NGC-Learn.shape = ", node.outputs.value.shape)
87+
a.pre.set(x)
88+
a.post.set(y)
89+
ctx.adapt(t=1., dt=dt)
90+
dK = a.dWeights.value
91+
#print(dK)
92+
ctx.backtransmit(t=1., dt=dt)
93+
dx = a.dInputs.value
94+
#print(dx)
95+
dK_truth = jnp.array(
96+
[[[[4.]],
97+
[[6.]]],
98+
[[[6.]],
99+
[[9.]]]]
100+
)
101+
dx_truth = jnp.array(
102+
[[[[9.],
103+
[6.]],
104+
[[6.],
105+
[4.]]]]
106+
)
107+
assert_array_equal(dK, dK_truth)
108+
assert_array_equal(dx, dx_truth)
109+
110+
#test_HebbianDeconvSynapse1()

0 commit comments

Comments
 (0)