1
1
from jax import random , numpy as jnp , jit
2
- from ngclearn import resolver , Component , Compartment
2
+ from ngcsimlib .compilers .process import transition
3
+ from ngcsimlib .component import Component
4
+ from ngcsimlib .compartment import Compartment
5
+
3
6
from .convSynapse import ConvSynapse
7
+ from ngclearn .utils .weight_distribution import initialize_params
8
+ from ngcsimlib .logger import info
9
+ from ngclearn .utils import tensorstats
10
+ import ngclearn .utils .weight_distribution as dist
4
11
from ngclearn .components .synapses .convolution .ngcconv import (_conv_same_transpose_padding ,
5
12
_conv_valid_transpose_padding )
6
13
from ngclearn .components .synapses .convolution .ngcconv import (conv2d , _calc_dX_conv ,
@@ -143,8 +150,9 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
143
150
self .x_delta_shape = (dx , dy )
144
151
145
152
@staticmethod
146
- def _compute_update (sign_value , w_decay , bias_init , stride , pad_args ,
147
- delta_shape , pre , post , weights ):
153
+ def _compute_update (
154
+ sign_value , w_decay , bias_init , stride , pad_args , delta_shape , pre , post , weights
155
+ ): ## synaptic kernel adjustment calculation co-routine
148
156
## compute adjustment to filters
149
157
dWeights = calc_dK_conv (pre , post , delta_shape = delta_shape ,
150
158
stride_size = stride , padding = pad_args )
@@ -157,10 +165,12 @@ def _compute_update(sign_value, w_decay, bias_init, stride, pad_args,
157
165
dBiases = jnp .sum (post , axis = 0 , keepdims = True ) * sign_value
158
166
return dWeights , dBiases
159
167
168
+ @transition (output_compartments = ["opt_params" , "weights" , "biases" , "dWeights" , "dBiases" ])
160
169
@staticmethod
161
- def _evolve (opt , sign_value , w_decay , w_bounds , is_nonnegative , bias_init ,
162
- stride , pad_args , delta_shape , pre , post , weights , biases ,
163
- opt_params ):
170
+ def evolve (
171
+ opt , sign_value , w_decay , w_bounds , is_nonnegative , bias_init , stride , pad_args , delta_shape , pre , post ,
172
+ weights , biases , opt_params
173
+ ):
164
174
## calc dFilters / dBiases - update to filters and biases
165
175
dWeights , dBiases = HebbianConvSynapse ._compute_update (
166
176
sign_value , w_decay , bias_init , stride , pad_args , delta_shape ,
@@ -180,17 +190,11 @@ def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
180
190
weights = jnp .clip (weights , - w_bounds , w_bounds )
181
191
return opt_params , weights , biases , dWeights , dBiases
182
192
183
- @resolver (_evolve )
184
- def evolve (self , opt_params , weights , biases , dWeights , dBiases ):
185
- self .opt_params .set (opt_params )
186
- self .weights .set (weights )
187
- self .biases .set (biases )
188
- self .dWeights .set (dWeights )
189
- self .dBiases .set (dBiases )
190
-
193
+ @transition (output_compartments = ["dInputs" ])
191
194
@staticmethod
192
- def _backtransmit (sign_value , x_size , shape , stride , padding , x_delta_shape ,
193
- antiPad , post , weights ): ## action-backpropagating routine
195
+ def backtransmit (
196
+ sign_value , x_size , shape , stride , padding , x_delta_shape , antiPad , post , weights
197
+ ): ## action-backpropagating routine
194
198
## calc dInputs - adjustment w.r.t. input signal
195
199
k_size , k_size , n_in_chan , n_out_chan = shape
196
200
# antiPad = None
@@ -206,12 +210,9 @@ def _backtransmit(sign_value, x_size, shape, stride, padding, x_delta_shape,
206
210
dInputs = dInputs * sign_value
207
211
return dInputs
208
212
209
- @resolver (_backtransmit )
210
- def backtransmit (self , dInputs ):
211
- self .dInputs .set (dInputs )
212
-
213
+ @transition (output_compartments = ["inputs" , "outputs" , "pre" , "post" , "dInputs" ])
213
214
@staticmethod
214
- def _reset (in_shape , out_shape ):
215
+ def reset (in_shape , out_shape ):
215
216
preVals = jnp .zeros (in_shape )
216
217
postVals = jnp .zeros (out_shape )
217
218
inputs = preVals
@@ -221,14 +222,6 @@ def _reset(in_shape, out_shape):
221
222
dInputs = preVals
222
223
return inputs , outputs , pre , post , dInputs
223
224
224
- @resolver (_reset )
225
- def reset (self , inputs , outputs , pre , post , dInputs ):
226
- self .inputs .set (inputs )
227
- self .outputs .set (outputs )
228
- self .pre .set (pre )
229
- self .post .set (post )
230
- self .dInputs .set (dInputs )
231
-
232
225
@classmethod
233
226
def help (cls ): ## component help function
234
227
properties = {
0 commit comments