@@ -1,6 +1,13 @@
 from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
 from .convSynapse import ConvSynapse
+from ngclearn.utils.weight_distribution import initialize_params
+from ngcsimlib.logger import info
+from ngclearn.utils import tensorstats
+import ngclearn.utils.weight_distribution as dist
 from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
                                                               _conv_valid_transpose_padding)
 from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
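Note on the migration pattern visible throughout this diff: the old `@resolver` idiom paired each static `_fn` with an instance method that manually called `set(...)` on each compartment for every returned value; the new `@transition(output_compartments=[...])` decorator names the destination compartments up front and routes the return tuple into them. The toy below is a minimal mock of that contract, inferred from the resolver bodies removed further down in this diff; it is not ngcsimlib's actual implementation, and `Compartment`/`apply_transition` here are stand-ins.

## Toy mock (NOT ngcsimlib's real code): route a transition-style static
## method's return tuple into the named compartments of a component.
class Compartment:
    def __init__(self, value):
        self.value = value
    def set(self, value):
        self.value = value

def apply_transition(component, fn, output_compartments, *args):
    outs = fn(*args)
    outs = outs if isinstance(outs, tuple) else (outs,)
    for name, value in zip(output_compartments, outs):
        getattr(component, name).set(value)  ## what each @resolver body did by hand

class ToySynapse:
    def __init__(self):
        self.weights = Compartment(1.0)
        self.dWeights = Compartment(0.0)

    @staticmethod
    def evolve(weights, dW, eta):  ## mirrors the migrated evolve's return contract
        return weights + dW * eta, dW

syn = ToySynapse()
apply_transition(syn, ToySynapse.evolve, ["weights", "dWeights"],
                 syn.weights.value, 0.5, 0.1)
print(syn.weights.value, syn.dWeights.value)  ## 1.05 0.5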
@@ -67,12 +74,14 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
     """

     # Define Functions
-    def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
-                 pretrace_target=0., filter_init=None, stride=1, padding=None,
-                 resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs):
-        super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
-                         bias_init=None, resist_scale=resist_scale, stride=stride,
-                         padding=padding, batch_size=batch_size, **kwargs)
+    def __init__(
+            self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1,
+            padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs
+    ):
+        super().__init__(
+            name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=None, resist_scale=resist_scale,
+            stride=stride, padding=padding, batch_size=batch_size, **kwargs
+        )

         self.eta = eta
         self.w_bound = w_bound ## soft weight constraint
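For orientation, a hypothetical construction call matching the reflowed signature above. All numeric values are illustrative only, `x_shape` is assumed here to be the side length of a square input map (consult ConvSynapse for the exact expected form), and in practice components are built inside an ngcsimlib Context:

## Hypothetical usage sketch; values are illustrative, not from this diff.
W1 = TraceSTDPConvSynapse(
    name="W1", shape=(5, 5, 1, 16),   ## 5x5 filters mapping 1 -> 16 channels
    x_shape=28,                       ## assumed: side of a square input map
    A_plus=1e-2, A_minus=1e-3,        ## LTP / LTD strengths
    eta=1e-2, pretrace_target=0., stride=1, padding=None,
    w_bound=1., w_decay=0., batch_size=1
)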
@@ -107,8 +116,7 @@ def __init__(self, name, shape, x_shape, A_plus, A_minus, eta=0.,
                                                      self.x_size, k_size, stride)
         ########################################################################

-    def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
-              weights):
+    def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
         _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
@@ -126,26 +134,28 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
         self.x_delta_shape = (dx, dy)

     @staticmethod
-    def _compute_update(pretrace_target, Aplus, Aminus, stride, pad_args,
-                        delta_shape, preSpike, preTrace, postSpike, postTrace):
+    def _compute_update(
+            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
+    ):
         ## Compute long-term potentiation to filters
-        dW_ltp = calc_dK_conv(preTrace - pretrace_target, postSpike * Aplus,
-                              delta_shape=delta_shape, stride_size=stride,
-                              padding=pad_args)
+        dW_ltp = calc_dK_conv(
+            preTrace - pretrace_target, postSpike * Aplus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+        )
         ## Compute long-term depression to filters
-        dW_ltd = -calc_dK_conv(preSpike, postTrace * Aminus,
-                               delta_shape=delta_shape, stride_size=stride,
-                               padding=pad_args)
+        dW_ltd = -calc_dK_conv(
+            preSpike, postTrace * Aminus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+        )
         dWeights = (dW_ltp + dW_ltd)
         return dWeights

+    @transition(output_compartments=["weights", "dWeights"])
     @staticmethod
-    def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
-                stride, pad_args, delta_shape, preSpike, preTrace, postSpike,
-                postTrace, weights, eta):
+    def evolve(
+            pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
+            postSpike, postTrace, weights, eta
+    ):
         dWeights = TraceSTDPConvSynapse._compute_update(
-            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape,
-            preSpike, preTrace, postSpike, postTrace
+            pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
         )
         if w_decay > 0.: ## apply synaptic decay
             weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
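The two `calc_dK_conv` calls above are the convolutional form of trace-based STDP: potentiation pairs the (target-shifted) presynaptic trace with postsynaptic spikes, while depression pairs presynaptic spikes with the postsynaptic trace. Below is a minimal dense (non-convolutional) analogue of the same update and of `evolve`'s decayed-ascent step, written in plain JAX rather than with the library's helpers; the array values and `eps` are assumptions for illustration.

import jax.numpy as jnp

def dense_trace_stdp(preSpike, preTrace, postSpike, postTrace, Aplus, Aminus, x_tar=0.):
    dW_ltp = (preTrace - x_tar).T @ (postSpike * Aplus)  ## LTP: pre trace x post spikes
    dW_ltd = -(preSpike.T @ (postTrace * Aminus))        ## LTD: pre spikes x post trace
    return dW_ltp + dW_ltd                               ## shape: (n_pre, n_post)

## One decayed STDP-ascent step, mirroring evolve (w_decay > 0 branch, then clip).
preSpike  = jnp.array([[1., 0.]]);  preTrace  = jnp.array([[0.6, 0.1]])
postSpike = jnp.array([[1.]]);      postTrace = jnp.array([[0.4]])
W = jnp.full((2, 1), 0.5)
dW = dense_trace_stdp(preSpike, preTrace, postSpike, postTrace, Aplus=1e-2, Aminus=1e-3)
eta, w_decay, w_bound, eps = 1e-1, 1e-3, 1., 1e-7  ## eps stands in for the source's epsilon
W = W + dW * eta - W * w_decay        ## decayed STDP-ascent
W = jnp.clip(W, eps, w_bound - eps)   ## keep weights strictly inside (0, w_bound)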
@@ -157,14 +167,11 @@ def _evolve(pretrace_target, Aplus, Aminus, w_decay, w_bound,
             weights = jnp.clip(weights, eps, w_bound - eps)
         return weights, dWeights

-    @resolver(_evolve)
-    def evolve(self, weights, dWeights):
-        self.weights.set(weights)
-        self.dWeights.set(dWeights)
-
+    @transition(output_compartments=["dInputs"])
     @staticmethod
-    def _backtransmit(x_size, shape, stride, padding, x_delta_shape, antiPad,
-                      postSpike, weights): ## action-backpropagating routine
+    def backtransmit(
+            x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
+    ): ## action-backpropagating routine
         ## calc dInputs - adjustment w.r.t. input signal
         k_size, k_size, n_in_chan, n_out_chan = shape
         # antiPad = None
@@ -174,16 +181,12 @@ def _backtransmit(x_size, shape, stride, padding, x_delta_shape, antiPad,
         # elif padding == "VALID":
         #     antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
         #                                             k_size, stride)
-        dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape,
-                               stride_size=stride, anti_padding=antiPad)
+        dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
         return dInputs

-    @resolver(_backtransmit)
-    def backtransmit(self, dInputs):
-        self.dInputs.set(dInputs)
-
+    @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
     @staticmethod
-    def _reset(in_shape, out_shape):
+    def reset(in_shape, out_shape):
         preVals = jnp.zeros(in_shape)
         postVals = jnp.zeros(out_shape)
         inputs = preVals
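`backtransmit` (previous hunk) carries postsynaptic spikes back toward the input space, which amounts to a transposed convolution with the synapse's filters; `calc_dX_conv` plus the precomputed `antiPad` handle the stride and anti-padding bookkeeping. Below is a shape-level sketch using generic `jax.lax` primitives, not the library's helper, with assumed sizes:

import jax.numpy as jnp
from jax import lax

## Forward: (1, 28, 28, 1) --5x5 VALID conv--> (1, 24, 24, 16)
x = jnp.zeros((1, 28, 28, 1))
K = jnp.zeros((5, 5, 1, 16))  ## HWIO kernel, matching `shape`
dn = ('NHWC', 'HWIO', 'NHWC')
post = lax.conv_general_dilated(x, K, window_strides=(1, 1), padding='VALID',
                                dimension_numbers=dn)
## Backward: transpose-convolve postsynaptic activity back to the input shape.
dInputs = lax.conv_transpose(post, K, strides=(1, 1), padding='VALID',
                             dimension_numbers=dn, transpose_kernel=True)
print(post.shape, dInputs.shape)  ## (1, 24, 24, 16) (1, 28, 28, 1)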
@@ -194,15 +197,6 @@ def _reset(in_shape, out_shape):
         postTrace = postVals
         return inputs, outputs, preSpike, postSpike, preTrace, postTrace

-    @resolver(_reset)
-    def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.preSpike.set(preSpike)
-        self.postSpike.set(postSpike)
-        self.preTrace.set(preTrace)
-        self.postTrace.set(postTrace)
-
     @classmethod
     def help(cls): ## component help function
         properties = {