Skip to content

Commit 5c56389

Browse files
author
Alexander Ororbia
committed
refactored conv/hebb-conv syn w/ unit-test
1 parent 33e0cc1 commit 5c56389

File tree

3 files changed

+144
-45
lines changed

3 files changed

+144
-45
lines changed

ngclearn/components/synapses/convolution/convSynapse.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
32
from ngclearn.components.jaxComponent import JaxComponent
4-
import ngclearn.utils.weight_distribution as dist
5-
from ngclearn.components.synapses.convolution.ngcconv import conv2d
3+
from ngcsimlib.compilers.process import transition
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
6+
7+
from ngclearn.utils.weight_distribution import initialize_params
68
from ngcsimlib.logger import info
79
from ngclearn.utils import tensorstats
10+
import ngclearn.utils.weight_distribution as dist
11+
from ngclearn.components.synapses.convolution.ngcconv import conv2d
812

913
class ConvSynapse(JaxComponent): ## base-level convolutional cable
1014
"""
@@ -44,8 +48,9 @@ class ConvSynapse(JaxComponent): ## base-level convolutional cable
4448
"""
4549

4650
# Define Functions
47-
def __init__(self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1,
48-
padding=None, resist_scale=1., batch_size=1, **kwargs):
51+
def __init__(
52+
self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1.,
53+
batch_size=1, **kwargs):
4954
super().__init__(name, **kwargs)
5055

5156
self.filter_init = filter_init
@@ -95,29 +100,22 @@ def __init__(self, name, shape, x_shape, filter_init=None, bias_init=None, strid
95100
(1, shape[1]))
96101
if bias_init else 0.0)
97102

103+
@transition(output_compartments=["outputs"])
98104
@staticmethod
99-
def _advance_state(Rscale, padding, stride, weights, biases, inputs):
105+
def advance_state(Rscale, padding, stride, weights, biases, inputs):
100106
_x = inputs
101107
outputs = conv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases
102108
return outputs
103109

104-
@resolver(_advance_state)
105-
def advance_state(self, outputs):
106-
self.outputs.set(outputs)
107-
110+
@transition(output_compartments=["inputs", "outputs"])
108111
@staticmethod
109-
def _reset(in_shape, out_shape):
112+
def reset(in_shape, out_shape):
110113
preVals = jnp.zeros(in_shape)
111114
postVals = jnp.zeros(out_shape)
112115
inputs = preVals
113116
outputs = postVals
114117
return inputs, outputs
115118

116-
@resolver(_reset)
117-
def reset(self, inputs, outputs):
118-
self.inputs.set(inputs)
119-
self.outputs.set(outputs)
120-
121119
def save(self, directory, **kwargs):
122120
file_name = directory + "/" + self.name + ".npz"
123121
if self.bias_init != None:

ngclearn/components/synapses/convolution/hebbianConvSynapse.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
36
from .convSynapse import ConvSynapse
7+
from ngclearn.utils.weight_distribution import initialize_params
8+
from ngcsimlib.logger import info
9+
from ngclearn.utils import tensorstats
10+
import ngclearn.utils.weight_distribution as dist
411
from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
512
_conv_valid_transpose_padding)
613
from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
@@ -143,8 +150,9 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
143150
self.x_delta_shape = (dx, dy)
144151

145152
@staticmethod
146-
def _compute_update(sign_value, w_decay, bias_init, stride, pad_args,
147-
delta_shape, pre, post, weights):
153+
def _compute_update(
154+
sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
155+
): ## synaptic kernel adjustment calculation co-routine
148156
## compute adjustment to filters
149157
dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape,
150158
stride_size=stride, padding=pad_args)
@@ -157,10 +165,12 @@ def _compute_update(sign_value, w_decay, bias_init, stride, pad_args,
157165
dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
158166
return dWeights, dBiases
159167

168+
@transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
160169
@staticmethod
161-
def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
162-
stride, pad_args, delta_shape, pre, post, weights, biases,
163-
opt_params):
170+
def evolve(
171+
opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, stride, pad_args, delta_shape, pre, post,
172+
weights, biases, opt_params
173+
):
164174
## calc dFilters / dBiases - update to filters and biases
165175
dWeights, dBiases = HebbianConvSynapse._compute_update(
166176
sign_value, w_decay, bias_init, stride, pad_args, delta_shape,
@@ -180,17 +190,11 @@ def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
180190
weights = jnp.clip(weights, -w_bounds, w_bounds)
181191
return opt_params, weights, biases, dWeights, dBiases
182192

183-
@resolver(_evolve)
184-
def evolve(self, opt_params, weights, biases, dWeights, dBiases):
185-
self.opt_params.set(opt_params)
186-
self.weights.set(weights)
187-
self.biases.set(biases)
188-
self.dWeights.set(dWeights)
189-
self.dBiases.set(dBiases)
190-
193+
@transition(output_compartments=["dInputs"])
191194
@staticmethod
192-
def _backtransmit(sign_value, x_size, shape, stride, padding, x_delta_shape,
193-
antiPad, post, weights): ## action-backpropagating routine
195+
def backtransmit(
196+
sign_value, x_size, shape, stride, padding, x_delta_shape, antiPad, post, weights
197+
): ## action-backpropagating routine
194198
## calc dInputs - adjustment w.r.t. input signal
195199
k_size, k_size, n_in_chan, n_out_chan = shape
196200
# antiPad = None
@@ -206,12 +210,9 @@ def _backtransmit(sign_value, x_size, shape, stride, padding, x_delta_shape,
206210
dInputs = dInputs * sign_value
207211
return dInputs
208212

209-
@resolver(_backtransmit)
210-
def backtransmit(self, dInputs):
211-
self.dInputs.set(dInputs)
212-
213+
@transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
213214
@staticmethod
214-
def _reset(in_shape, out_shape):
215+
def reset(in_shape, out_shape):
215216
preVals = jnp.zeros(in_shape)
216217
postVals = jnp.zeros(out_shape)
217218
inputs = preVals
@@ -221,14 +222,6 @@ def _reset(in_shape, out_shape):
221222
dInputs = preVals
222223
return inputs, outputs, pre, post, dInputs
223224

224-
@resolver(_reset)
225-
def reset(self, inputs, outputs, pre, post, dInputs):
226-
self.inputs.set(inputs)
227-
self.outputs.set(outputs)
228-
self.pre.set(pre)
229-
self.post.set(post)
230-
self.dInputs.set(dInputs)
231-
232225
@classmethod
233226
def help(cls): ## component help function
234227
properties = {
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import HebbianConvSynapse
6+
import ngclearn.utils.weight_distribution as dist
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
15+
def test_HebbianConvSynapse1():
16+
name = "hebb_conv_ctx"
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
padding_style = "SAME"
22+
stride = 1 #2
23+
batch_size = 1
24+
w_size = 2
25+
n_in_chan = 1 #4 # 1
26+
n_out_chan = 1 #5 # 1
27+
x_size = 2 #4
28+
29+
shape = (w_size, w_size, n_in_chan, n_out_chan)
30+
x_shape = (batch_size, x_size, x_size, n_in_chan)
31+
32+
# ---- build a simple Hebb-Conv system ----
33+
with Context(name) as ctx:
34+
a = HebbianConvSynapse(
35+
"a", shape, (x_size, x_size), eta=0., filter_init=dist.constant(value=1.), bias_init=None,
36+
stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0]
37+
)
38+
39+
#"""
40+
evolve_process = (Process()
41+
>> a.evolve)
42+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
43+
44+
backtransmit_process = (Process()
45+
>> a.backtransmit)
46+
ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit")
47+
48+
advance_process = (Process()
49+
>> a.advance_state)
50+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
51+
52+
reset_process = (Process()
53+
>> a.reset)
54+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
55+
#"""
56+
57+
"""
58+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
59+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
60+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
61+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
62+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
63+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
64+
backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit")
65+
ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit")
66+
"""
67+
68+
x = jnp.ones(x_shape)
69+
70+
ctx.reset()
71+
a.inputs.set(x)
72+
ctx.run(t=1., dt=dt)
73+
y = a.outputs.value
74+
75+
y_truth = jnp.array(
76+
[[[[4.],[2.]],
77+
[[2.], [1.]]]]
78+
)
79+
80+
assert_array_equal(y, y_truth)
81+
# print(y)
82+
# print("======")
83+
84+
# print("NGC-Learn.shape = ", node.outputs.value.shape)
85+
a.pre.set(x)
86+
a.post.set(y)
87+
ctx.adapt(t=1., dt=dt)
88+
dK = a.dWeights.value
89+
#print(dK)
90+
ctx.backtransmit(t=1., dt=dt)
91+
dx = a.dInputs.value
92+
#print(dx)
93+
dK_truth = jnp.array(
94+
[[[[9.]],
95+
[[6.]]],
96+
[[[6.]],
97+
[[4.]]]]
98+
)
99+
dx_truth = jnp.array(
100+
[[[[4.],
101+
[6.]],
102+
[[6.],
103+
[9.]]]]
104+
)
105+
assert_array_equal(dK, dK_truth)
106+
assert_array_equal(dx, dx_truth)
107+
108+
#test_HebbianConvSynapse1()

0 commit comments

Comments
 (0)