1
1
from jax import random , numpy as jnp , jit
2
- from ngclearn import resolver , Component , Compartment
2
+ from ngcsimlib .compilers .process import transition
3
+ from ngcsimlib .component import Component
4
+ from ngcsimlib .compartment import Compartment
5
+
3
6
from .deconvSynapse import DeconvSynapse
7
+ from ngclearn .utils .weight_distribution import initialize_params
8
+ from ngcsimlib .logger import info
9
+ from ngclearn .utils import tensorstats
10
+ import ngclearn .utils .weight_distribution as dist
4
11
from ngclearn .components .synapses .convolution .ngcconv import (deconv2d , _calc_dX_deconv ,
5
12
_calc_dK_deconv , calc_dX_deconv ,
6
13
calc_dK_deconv )
@@ -79,13 +86,15 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
79
86
"""
80
87
81
88
# Define Functions
82
- def __init__ (self , name , shape , x_shape , eta = 0. , filter_init = None , bias_init = None ,
83
- stride = 1 , padding = None , resist_scale = 1. , w_bound = 0. , is_nonnegative = False ,
84
- w_decay = 0. , sign_value = 1. , optim_type = "sgd" , batch_size = 1 , ** kwargs ):
85
- super ().__init__ (name , shape , x_shape = x_shape , filter_init = filter_init ,
86
- bias_init = bias_init , resist_scale = resist_scale ,
87
- stride = stride , padding = padding , batch_size = batch_size ,
88
- ** kwargs )
89
+ def __init__ (
90
+ self , name , shape , x_shape , eta = 0. , filter_init = None , bias_init = None , stride = 1 , padding = None ,
91
+ resist_scale = 1. , w_bound = 0. , is_nonnegative = False , w_decay = 0. , sign_value = 1. , optim_type = "sgd" ,
92
+ batch_size = 1 , ** kwargs
93
+ ):
94
+ super ().__init__ (
95
+ name , shape , x_shape = x_shape , filter_init = filter_init , bias_init = bias_init , resist_scale = resist_scale ,
96
+ stride = stride , padding = padding , batch_size = batch_size , ** kwargs
97
+ )
89
98
90
99
self .eta = eta
91
100
self .w_bounds = w_bound
@@ -112,8 +121,7 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
112
121
[self .weights .value , self .biases .value ]
113
122
if bias_init else [self .weights .value ]))
114
123
115
- def _init (self , batch_size , x_size , shape , stride , padding , pad_args ,
116
- weights ):
124
+ def _init (self , batch_size , x_size , shape , stride , padding , pad_args , weights ):
117
125
k_size , k_size , n_in_chan , n_out_chan = shape
118
126
_x = jnp .zeros ((batch_size , x_size , x_size , n_in_chan ))
119
127
_d = deconv2d (_x , self .weights .value , stride_size = self .stride ,
@@ -132,8 +140,7 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
132
140
self .x_delta_shape = (dx , dy )
133
141
134
142
@staticmethod
135
- def _compute_update (sign_value , w_decay , bias_init , shape , stride , padding ,
136
- delta_shape , pre , post , weights ):
143
+ def _compute_update (sign_value , w_decay , bias_init , shape , stride , padding , delta_shape , pre , post , weights ):
137
144
k_size , k_size , n_in_chan , n_out_chan = shape
138
145
## compute adjustment to filters
139
146
dWeights = calc_dK_deconv (pre , post , delta_shape = delta_shape ,
@@ -148,10 +155,12 @@ def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding,
148
155
dBiases = jnp .sum (post , axis = 0 , keepdims = True ) * sign_value
149
156
return dWeights , dBiases
150
157
158
+ @transition (output_compartments = ["opt_params" , "weights" , "biases" , "dWeights" , "dBiases" ])
151
159
@staticmethod
152
- def _evolve (opt , sign_value , w_decay , w_bounds , is_nonnegative , bias_init ,
153
- shape , stride , padding , delta_shape , pre , post , weights , biases ,
154
- opt_params ):
160
+ def evolve (
161
+ opt , sign_value , w_decay , w_bounds , is_nonnegative , bias_init , shape , stride , padding , delta_shape ,
162
+ pre , post , weights , biases , opt_params
163
+ ):
155
164
dWeights , dBiases = HebbianDeconvSynapse ._compute_update (
156
165
sign_value , w_decay , bias_init , shape , stride , padding , delta_shape ,
157
166
pre , post , weights
@@ -169,30 +178,19 @@ def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
169
178
weights = jnp .clip (weights , - w_bounds , w_bounds )
170
179
return opt_params , weights , biases , dWeights , dBiases
171
180
172
- @resolver (_evolve )
173
- def evolve (self , opt_params , weights , biases , dWeights , dBiases ):
174
- self .opt_params .set (opt_params )
175
- self .weights .set (weights )
176
- self .biases .set (biases )
177
- self .dWeights .set (dWeights )
178
- self .dBiases .set (dBiases )
179
-
181
+ @transition (output_compartments = ["dInputs" ])
180
182
@staticmethod
181
- def _backtransmit (sign_value , stride , padding , x_delta_shape , pre , post ,
182
- weights ): ## action-backpropagating routine
183
+ def backtransmit (sign_value , stride , padding , x_delta_shape , pre , post , weights ): ## action-backpropagating routine
183
184
## calc dInputs
184
185
dInputs = calc_dX_deconv (weights , post , delta_shape = x_delta_shape ,
185
186
stride_size = stride , padding = padding )
186
187
## flip sign of back-transmitted signal (if applicable)
187
188
dInputs = dInputs * sign_value
188
189
return dInputs
189
190
190
- @resolver (_backtransmit )
191
- def backtransmit (self , dInputs ):
192
- self .dInputs .set (dInputs )
193
-
191
+ @transition (output_compartments = ["inputs" , "outputs" , "pre" , "post" , "dInputs" ])
194
192
@staticmethod
195
- def _reset (in_shape , out_shape ):
193
+ def reset (in_shape , out_shape ):
196
194
preVals = jnp .zeros (in_shape )
197
195
postVals = jnp .zeros (out_shape )
198
196
inputs = preVals
@@ -202,14 +200,6 @@ def _reset(in_shape, out_shape):
202
200
dInputs = preVals
203
201
return inputs , outputs , pre , post , dInputs
204
202
205
- @resolver (_reset )
206
- def reset (self , inputs , outputs , pre , post , dInputs ):
207
- self .inputs .set (inputs )
208
- self .outputs .set (outputs )
209
- self .pre .set (pre )
210
- self .post .set (post )
211
- self .dInputs .set (dInputs )
212
-
213
203
@classmethod
214
204
def help (cls ): ## component help function
215
205
properties = {
0 commit comments