
Commit ce13554

Author: Alexander Ororbia (committed)
Message: revised/refactored hebb/stdp conv/deconv syn w/ unit-tests
Parent: ff2628c

File tree

7 files changed: +311, -105 lines changed


ngclearn/components/synapses/convolution/convSynapse.py

Lines changed: 3 additions & 3 deletions

@@ -97,9 +97,9 @@ def __init__(
         if self.bias_init is None:
             info(self.name, "is using default bias value of zero (no bias "
                             "kernel provided)!")
-        self.biases = Compartment(dist.initialize_params(subkeys[2], bias_init,
-                                                         (1, shape[1]))
-                                  if bias_init else 0.0)
+        self.biases = Compartment(
+            dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
+        )

     @transition(output_compartments=["outputs"])
     @staticmethod
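
The refactored `biases` compartment above folds the conditional bias setup into a single expression. A minimal standalone sketch of that pattern, with a stand-in initializer in place of ngclearn's `dist.initialize_params` (the names and shapes here are illustrative, not the library's API):

    from jax import numpy as jnp

    def init_bias(bias_cfg, shape):
        ## stand-in for dist.initialize_params: fill with a configured constant
        return jnp.zeros(shape) + bias_cfg.get("value", 0.)

    bias_init = None  # e.g., {"dist": "constant", "value": 0.05} to enable biases
    n_out_chan = 16

    ## mirrors the refactored one-liner: a (1, n_out_chan) bias kernel when a config
    ## is given, otherwise the scalar 0.0, which broadcasts safely in later additions
    biases = init_bias(bias_init, (1, n_out_chan)) if bias_init else 0.0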

ngclearn/components/synapses/convolution/hebbianConvSynapse.py

Lines changed: 5 additions & 10 deletions

@@ -133,8 +133,7 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
             [self.weights.value, self.biases.value]
             if bias_init else [self.weights.value]))

-    def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
-              weights):
+    def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
         _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0

@@ -155,8 +154,7 @@ def _compute_update(
         sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
     ): ## synaptic kernel adjustment calculation co-routine
         ## compute adjustment to filters
-        dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape,
-                                stride_size=stride, padding=pad_args)
+        dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape, stride_size=stride, padding=pad_args)
         dWeights = dWeights * sign_value
         if w_decay > 0.: ## apply synaptic decay
             dWeights = dWeights - weights * w_decay

@@ -174,12 +172,10 @@ def evolve(
     ):
         ## calc dFilters / dBiases - update to filters and biases
         dWeights, dBiases = HebbianConvSynapse._compute_update(
-            sign_value, w_decay, bias_init, stride, pad_args, delta_shape,
-            pre, post, weights
+            sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
         )
         if bias_init != None:
-            opt_params, [weights, biases] = opt(opt_params, [weights, biases],
-                                                [dWeights, dBiases])
+            opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
         else: ## ignore dBiases since no biases configured
             opt_params, [weights] = opt(opt_params, [weights], [dWeights])

@@ -205,8 +201,7 @@ def backtransmit(
         # elif padding == "VALID":
         #     antiPad = _conv_valid_transpose_padding(post.shape[1], x_size,
         #                                             k_size, stride)
-        dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape,
-                               stride_size=stride, anti_padding=antiPad)
+        dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
         ## flip sign of back-transmitted signal (if applicable)
         dInputs = dInputs * sign_value
         return dInputs
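
The `opt(opt_params, [...], [...])` calls in `evolve` above follow a functional optimizer contract: optimizer state in, parameter list in, update list in; new state and new parameters out. A hand-rolled SGD stand-in illustrating that contract (a sketch only, not ngclearn's actual `get_opt_step_fn`):

    from jax import numpy as jnp

    def sgd_step(opt_params, params, updates, lr=1e-2):
        ## opt_params would carry optimizer state (e.g., Adam moments); plain SGD keeps none.
        ## The update is *added* (Hebbian ascent); the sign_value factor in _compute_update
        ## pre-flips the sign whenever descent is the desired direction.
        new_params = [p + lr * u for p, u in zip(params, updates)]
        return opt_params, new_params

    opt_params = []
    weights = jnp.ones((3, 3, 1, 8))   # k_size x k_size x n_in_chan x n_out_chan
    dWeights = jnp.ones_like(weights)  # in the real component this comes from calc_dK_conv
    opt_params, [weights] = sgd_step(opt_params, [weights], [dWeights])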

ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py

Lines changed: 6 additions & 9 deletions

@@ -143,9 +143,9 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
     def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         ## compute adjustment to filters
-        dWeights = calc_dK_deconv(pre, post, delta_shape=delta_shape,
-                                  stride_size=stride, out_size=k_size,
-                                  padding=padding)
+        dWeights = calc_dK_deconv(
+            pre, post, delta_shape=delta_shape, stride_size=stride, out_size=k_size, padding=padding
+        )
         dWeights = dWeights * sign_value
         if w_decay > 0.: ## apply synaptic decay
             dWeights = dWeights - weights * w_decay

@@ -162,12 +162,10 @@ def evolve(
         pre, post, weights, biases, opt_params
     ):
         dWeights, dBiases = HebbianDeconvSynapse._compute_update(
-            sign_value, w_decay, bias_init, shape, stride, padding, delta_shape,
-            pre, post, weights
+            sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights
         )
         if bias_init != None:
-            opt_params, [weights, biases] = opt(opt_params, [weights, biases],
-                                                [dWeights, dBiases])
+            opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
         else: ## ignore dBiases since no biases configured
             opt_params, [weights] = opt(opt_params, [weights], [dWeights])
         ## apply any enforced filter constraints

@@ -182,8 +180,7 @@ def evolve(
     @staticmethod
     def backtransmit(sign_value, stride, padding, x_delta_shape, pre, post, weights): ## action-backpropagating routine
         ## calc dInputs
-        dInputs = calc_dX_deconv(weights, post, delta_shape=x_delta_shape,
-                                 stride_size=stride, padding=padding)
+        dInputs = calc_dX_deconv(weights, post, delta_shape=x_delta_shape, stride_size=stride, padding=padding)
         ## flip sign of back-transmitted signal (if applicable)
         dInputs = dInputs * sign_value
         return dInputs
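
For reference, the output sizing that these deconvolutional (transposed-convolution) filters obey under VALID-style padding follows the standard rule out = (in - 1) * stride + k_size; a two-line check (the general formula, not a quote of ngclearn internals):

    in_size, stride, k_size = 4, 2, 3
    out_size = (in_size - 1) * stride + k_size   # -> 9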

ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py

Lines changed: 39 additions & 45 deletions

@@ -1,6 +1,13 @@
 from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
 from .convSynapse import ConvSynapse
+from ngclearn.utils.weight_distribution import initialize_params
+from ngcsimlib.logger import info
+from ngclearn.utils import tensorstats
+import ngclearn.utils.weight_distribution as dist
 from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
                                                               _conv_valid_transpose_padding)
 from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,

@@ -67,12 +74,14 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
     """

     # Define Functions
-    def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
-                 pretrace_target=0., filter_init=None, stride=1, padding=None,
-                 resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs):
-        super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
-                         bias_init=None, resist_scale=resist_scale, stride=stride,
-                         padding=padding, batch_size=batch_size, **kwargs)
+    def __init__(
+            self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1,
+            padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs
+    ):
+        super().__init__(
+            name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=None, resist_scale=resist_scale,
+            stride=stride, padding=padding, batch_size=batch_size, **kwargs
+        )

         self.eta = eta
         self.w_bound = w_bound ## soft weight constraint

@@ -107,8 +116,7 @@ def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
                                                        self.x_size, k_size, stride)
         ########################################################################

-    def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
-              weights):
+    def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
         _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0

@@ -126,26 +134,28 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
         self.x_delta_shape = (dx, dy)

     @staticmethod
-    def _compute_update(pretrace_target, Aplus, Aminus, stride, pad_args,
-                        delta_shape, preSpike, preTrace, postSpike, postTrace):
+    def _compute_update(
+            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
+    ):
         ## Compute long-term potentiation to filters
-        dW_ltp = calc_dK_conv(preTrace - pretrace_target, postSpike * Aplus,
-                              delta_shape=delta_shape, stride_size=stride,
-                              padding=pad_args)
+        dW_ltp = calc_dK_conv(
+            preTrace - pretrace_target, postSpike * Aplus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+        )
         ## Compute long-term depression to filters
-        dW_ltd = -calc_dK_conv(preSpike, postTrace * Aminus,
-                               delta_shape=delta_shape, stride_size=stride,
-                               padding=pad_args)
+        dW_ltd = -calc_dK_conv(
+            preSpike, postTrace * Aminus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+        )
         dWeights = (dW_ltp + dW_ltd)
         return dWeights

+    @transition(output_compartments=["weights", "dWeights"])
     @staticmethod
-    def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
-                stride, pad_args, delta_shape, preSpike, preTrace, postSpike,
-                postTrace, weights, eta):
+    def evolve(
+            pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
+            postSpike, postTrace, weights, eta
+    ):
         dWeights = TraceSTDPConvSynapse._compute_update(
-            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape,
-            preSpike, preTrace, postSpike, postTrace
+            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
         )
         if w_decay > 0.: ## apply synaptic decay
             weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent

@@ -157,14 +167,11 @@ def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
         weights = jnp.clip(weights, eps, w_bound - eps)
         return weights, dWeights

-    @resolver(_evolve)
-    def evolve(self, weights, dWeights):
-        self.weights.set(weights)
-        self.dWeights.set(dWeights)
-
+    @transition(output_compartments=["dInputs"])
     @staticmethod
-    def _backtransmit(x_size, shape, stride, padding, x_delta_shape, antiPad,
-                      postSpike, weights): ## action-backpropagating routine
+    def backtransmit(
+            x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
+    ): ## action-backpropagating routine
         ## calc dInputs - adjustment w.r.t. input signal
         k_size, k_size, n_in_chan, n_out_chan = shape
         # antiPad = None

@@ -174,16 +181,12 @@ def _backtransmit(x_size, shape, stride, padding, x_delta_shape, antiPad,
         # elif padding == "VALID":
         #     antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
         #                                             k_size, stride)
-        dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape,
-                               stride_size=stride, anti_padding=antiPad)
+        dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
         return dInputs

-    @resolver(_backtransmit)
-    def backtransmit(self, dInputs):
-        self.dInputs.set(dInputs)
-
+    @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
     @staticmethod
-    def _reset(in_shape, out_shape):
+    def reset(in_shape, out_shape):
         preVals = jnp.zeros(in_shape)
         postVals = jnp.zeros(out_shape)
         inputs = preVals

@@ -194,15 +197,6 @@ def _reset(in_shape, out_shape):
         postTrace = postVals
         return inputs, outputs, preSpike, postSpike, preTrace, postTrace

-    @resolver(_reset)
-    def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.preSpike.set(preSpike)
-        self.postSpike.set(postSpike)
-        self.preTrace.set(preTrace)
-        self.postTrace.set(postTrace)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
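
The deletions above remove the old `@resolver` wrapper methods, whose only job was to copy returned values back into compartments; the new `@transition(output_compartments=[...])` decorator declares that mapping instead. A toy decorator sketching the contract this diff implies (illustrative only; ngcsimlib's real transition machinery is more involved and its internals are not shown here):

    def transition(output_compartments):
        def wrap(fn):
            def apply(component, **kwargs):
                outs = fn(**kwargs)                      # the pure static method does the math
                if not isinstance(outs, tuple):
                    outs = (outs,)
                for name, value in zip(output_compartments, outs):
                    getattr(component, name).set(value)  # e.g., component.weights.set(...)
            return apply
        return wrap

Under this scheme, each routine sheds several lines of `.set(...)` boilerplate, and `evolve`, `backtransmit`, and `reset` stay pure functions that are straightforward to test in isolation.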

ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py

Lines changed: 31 additions & 38 deletions

@@ -1,9 +1,17 @@
 from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
 from .deconvSynapse import DeconvSynapse
+from ngclearn.utils.weight_distribution import initialize_params
+from ngcsimlib.logger import info
+from ngclearn.utils import tensorstats
+import ngclearn.utils.weight_distribution as dist
 from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv,
                                                               _calc_dK_deconv, calc_dX_deconv,
                                                               calc_dK_deconv)
+from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn

 class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional cable
     """

@@ -65,13 +73,14 @@ class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional
     """

     # Define Functions
-    def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
-                 pretrace_target=0., filter_init=None, stride=1, padding=None,
-                 resist_scale=1., w_bound=0., w_decay=0., batch_size=1,
-                 **kwargs):
-        super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
-                         bias_init=None, resist_scale=resist_scale, stride=stride,
-                         padding=padding, batch_size=batch_size, **kwargs)
+    def __init__(
+            self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1,
+            padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs
+    ):
+        super().__init__(
+            name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=None, resist_scale=resist_scale,
+            stride=stride, padding=padding, batch_size=batch_size, **kwargs
+        )

         self.eta = eta
         self.w_bound = w_bound ## soft weight constraint

@@ -93,12 +102,10 @@ def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,

         ########################################################################
         ## Shape error correction -- do shape correction inference (for local updates)
-        self._init(self.batch_size, self.x_size, self.shape, self.stride,
-                   self.padding, self.pad_args, self.weights)
+        self._init(self.batch_size, self.x_size, self.shape, self.stride, self.padding, self.pad_args, self.weights)
         ########################################################################

-    def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
-              weights):
+    def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
         _d = deconv2d(_x, self.weights.value, stride_size=self.stride,

@@ -117,8 +124,9 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
         self.x_delta_shape = (dx, dy)

     @staticmethod
-    def _compute_update(pretrace_target, Aplus, Aminus, shape, stride, padding,
-                        delta_shape, preSpike, preTrace, postSpike, postTrace):
+    def _compute_update(
+            pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape, preSpike, preTrace, postSpike, postTrace
+    ):
         k_size, k_size, n_in_chan, n_out_chan = shape
         ## calc dFilters
         dW_ltp = calc_dK_deconv(preTrace - pretrace_target, postSpike * Aplus,

@@ -130,10 +138,12 @@ def _compute_update(pretrace_target, Aplus, Aminus, shape, stride, padding,
         dWeights = (dW_ltp + dW_ltd)
         return dWeights

+    @transition(output_compartments=["weights", "dWeights"])
     @staticmethod
-    def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
-                shape, stride, padding, delta_shape, preSpike, preTrace, postSpike,
-                postTrace, weights, eta):
+    def evolve(
+            pretrace_target, Aplus, Aminus, w_decay, w_bound, shape, stride, padding, delta_shape, preSpike, preTrace,
+            postSpike, postTrace, weights, eta
+    ):
         dWeights = TraceSTDPDeconvSynapse._compute_update(
             pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape,
             preSpike, preTrace, postSpike, postTrace

@@ -148,25 +158,17 @@ def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
         weights = jnp.clip(weights, eps, w_bound - eps)
         return weights, dWeights

-    @resolver(_evolve)
-    def evolve(self, weights, dWeights):
-        self.weights.set(weights)
-        self.dWeights.set(dWeights)
-
+    @transition(output_compartments=["dInputs"])
     @staticmethod
-    def _backtransmit(stride, padding, x_delta_shape, preSpike, postSpike,
-                      weights): ## action-backpropagating routine
+    def backtransmit(stride, padding, x_delta_shape, preSpike, postSpike, weights): ## action-backpropagating routine
         ## calc dInputs
         dInputs = calc_dX_deconv(weights, postSpike, delta_shape=x_delta_shape,
                                  stride_size=stride, padding=padding)
         return dInputs

-    @resolver(_backtransmit)
-    def backtransmit(self, dInputs):
-        self.dInputs.set(dInputs)
-
+    @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
     @staticmethod
-    def _reset(in_shape, out_shape):
+    def reset(in_shape, out_shape):
         preVals = jnp.zeros(in_shape)
         postVals = jnp.zeros(out_shape)
         inputs = preVals

@@ -177,15 +179,6 @@ def _reset(in_shape, out_shape):
         postTrace = postVals
         return inputs, outputs, preSpike, postSpike, preTrace, postTrace

-    @resolver(_reset)
-    def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.preSpike.set(preSpike)
-        self.postSpike.set(postSpike)
-        self.preTrace.set(preTrace)
-        self.postTrace.set(postTrace)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
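
The commit message mentions accompanying unit tests, but the test files are not among the five diffs shown in this excerpt (seven files changed in total). Because evolve/backtransmit/reset are now pure static methods, they lend themselves to direct checks; a hypothetical example of such a test (it assumes the decorated static method remains directly callable, and the shapes are made up):

    from jax import numpy as jnp

    def test_reset_zeroes_compartments():
        in_shape, out_shape = (1, 8, 8, 1), (1, 4, 4, 8)
        outs = TraceSTDPDeconvSynapse.reset(in_shape, out_shape)
        inputs, outputs, preSpike, postSpike, preTrace, postTrace = outs
        assert inputs.shape == in_shape and postTrace.shape == out_shape
        assert float(jnp.abs(outputs).sum()) == 0.0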
