
Commit bc713f6

committed
update patched synapses and their test cases
1 parent 8317d3f commit bc713f6

4 files changed: +168 -34 lines changed

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 5 additions & 23 deletions

@@ -5,6 +5,7 @@
 from ngclearn import resolver, Component, Compartment
 from ngclearn.components.synapses.patched import PatchedSynapse
 from ngclearn.utils import tensorstats
+from ngcsimlib.compilers.process import transition

 @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
 def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
@@ -240,8 +241,9 @@ def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, prior_type, pri

         return dW * jnp.where(0 != jnp.abs(weights), 1, 0), db

+    @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
     @staticmethod
-    def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
+    def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                post_wght, bias_init, pre, post, weights, biases, opt_params):
         ## calculate synaptic update values
         dWeights, dBiases = HebbianPatchedSynapse._compute_update(
@@ -258,16 +260,9 @@ def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_
         weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative)
         return opt_params, weights, biases, dWeights, dBiases

-    @resolver(_evolve)
-    def evolve(self, opt_params, weights, biases, dWeights, dBiases):
-        self.opt_params.set(opt_params)
-        self.weights.set(weights)
-        self.biases.set(biases)
-        self.dWeights.set(dWeights)
-        self.dBiases.set(dBiases)
-
+    @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"])
     @staticmethod
-    def _reset(batch_size, shape):
+    def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
         postVals = jnp.zeros((batch_size, shape[1]))
         return (
@@ -280,19 +275,6 @@ def _reset(batch_size, shape):
         )


-
-    @resolver(_reset)
-    def reset(self, inputs, outputs, pre, post, dWeights, dBiases):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.pre.set(pre)
-        self.post.set(post)
-        self.dWeights.set(dWeights)
-        self.dBiases.set(dBiases)
-
-
-
-
     @classmethod
     def help(cls): ## component help function
         properties = {
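Note (not part of the diff): the retained `return dW * jnp.where(0 != jnp.abs(weights), 1, 0), db` line masks the weight update so that entries which are exactly zero in `weights` receive no update, which is how the patch structure is preserved during learning. A minimal stand-alone sketch of just that masking expression, using toy values rather than the component's real compartments:

from jax import numpy as jnp

# Toy weight matrix with one masked-out (zero) entry and a dummy raw update.
weights = jnp.array([[0.5, 0.0],
                     [0.3, 0.2]])
dW = jnp.ones_like(weights)

# Same masking expression as in _compute_update: only nonzero weights receive updates.
masked_dW = dW * jnp.where(0 != jnp.abs(weights), 1, 0)
print(masked_dW)  # [[1. 0.]
                  #  [1. 1.]]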

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 6 additions & 11 deletions

@@ -3,6 +3,7 @@
 from ngclearn import resolver, Component, Compartment
 from ngclearn.components.jaxComponent import JaxComponent
 from ngclearn.utils import tensorstats
+from ngcsimlib.compilers.process import transition
 from ngclearn.utils.weight_distribution import initialize_params
 from ngcsimlib.logger import info
 import math
@@ -135,28 +136,22 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), w_mask=None, w
         self.biases = Compartment(initialize_params(subkeys[2], bias_init,
                                                      (1, self.shape[1]))
                                   if bias_init else 0.0)
+
+    @transition(output_compartments=["outputs"])
     @staticmethod
-    def _advance_state(Rscale, inputs, weights, biases):
+    def advance_state(Rscale, inputs, weights, biases):
         outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
         return outputs

-    @resolver(_advance_state)
-    def advance_state(self, outputs):
-        self.outputs.set(outputs)
-
+    @transition(output_compartments=["inputs", "outputs"])
     @staticmethod
-    def _reset(batch_size, shape):
+    def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
         postVals = jnp.zeros((batch_size, shape[1]))
         inputs = preVals
         outputs = postVals
         return inputs, outputs

-    @resolver(_reset)
-    def reset(self, inputs, outputs):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-
     def save(self, directory, **kwargs):
         file_name = directory + "/" + self.name + ".npz"
         if self.bias_init != None:
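Note (not part of the diff): the body of `advance_state` is unchanged here; it is the plain linear transmission `(inputs @ weights) * Rscale + biases`. A minimal stand-alone sketch of that computation with toy shapes, mirroring the expected-output check in the new PatchedSynapse test further down:

from jax import numpy as jnp, random

key = random.PRNGKey(0)
inputs = random.normal(key, (1, 12))   # one batch row of 12 input features
weights = jnp.ones((12, 12)) * 0.1     # toy weight matrix
biases = jnp.zeros((1, 12))
Rscale = 1.0                           # resistance scale

# Same computation as the advance_state body above.
outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
print(outputs.shape)  # (1, 12)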
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# %%
+
+from jax import numpy as jnp, random, jit
+from ngcsimlib.context import Context
+import numpy as np
+np.random.seed(42)
+from ngclearn.components import HebbianPatchedSynapse
+from ngcsimlib.compilers import compile_command, wrap_command
+from numpy.testing import assert_array_equal
+
+from ngcsimlib.compilers.process import Process, transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.context import Context
+from ngcsimlib.utils.compartment import Get_Compartment_Batch
+
+
+def test_hebbianPatchedSynapse():
+    np.random.seed(42)
+    name = "hebbian_patched_synapse_ctx"
+    dkey = random.PRNGKey(42)
+    dkey, *subkeys = random.split(dkey, 100)
+    dt = 1.  # ms
+
+    # model hyper
+    shape = (10, 5)
+    n_sub_models = 2
+    stride_shape = (1, 1)
+    batch_size = 1
+    resist_scale = 1.0
+
+    with Context(name) as ctx:
+        a = HebbianPatchedSynapse(
+            name="a",
+            shape=shape,
+            n_sub_models=n_sub_models,
+            stride_shape=stride_shape,
+            resist_scale=resist_scale,
+            batch_size=batch_size
+        )
+
+        advance_process = (Process() >> a.advance_state)
+        ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+        reset_process = (Process() >> a.reset)
+        ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+        evolve_process = (Process() >> a.evolve)
+        ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
+
+        # Compile and add commands
+        # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
+        # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
+        # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
+        # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
+        # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
+        # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
+
+        @Context.dynamicCommand
+        def clamp_inputs(x):
+            a.inputs.set(x)
+
+        @Context.dynamicCommand
+        def clamp_pre(x):
+            a.pre.set(x)
+
+        @Context.dynamicCommand
+        def clamp_post(x):
+            a.post.set(x)
+
+    a.weights.set(jnp.ones((12, 12)) * 0.5)
+
+    in_pre = jnp.ones((10, 12)) * 1.0
+    in_post = jnp.ones((10, 12)) * 0.75
+
+    ctx.reset()
+    clamp_pre(in_pre)
+    clamp_post(in_post)
+    ctx.run(t=1. * dt, dt=dt)
+    ctx.evolve(t=1. * dt, dt=dt)
+
+    print(a.weights.value)
+
+    # Basic assertions to check learning dynamics
+    assert a.weights.value.shape == (12, 12), ""
+    assert a.weights.value[0, 0] == 0.5, ""
+
+
+# test_hebbianPatchedSynapse()
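Note (not part of the diff): the clamped `pre` and `post` values in this test are batch-major (10 samples by 12 features), so a Hebbian-style update formed from them has the same shape as the (12, 12) weight matrix the test installs. A small shape check under that assumption; the exact scaling, sign, and masking inside `_calc_update` are not shown in these hunks:

from jax import numpy as jnp

pre = jnp.ones((10, 12)) * 1.0    # clamped presynaptic batch from the test
post = jnp.ones((10, 12)) * 0.75  # clamped postsynaptic batch from the test

# Outer-product-style Hebbian term; its shape matches the (12, 12) weight matrix.
dW = jnp.matmul(pre.T, post)
print(dW.shape)  # (12, 12)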
Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# %%
+
+from jax import numpy as jnp, random, jit
+from ngcsimlib.context import Context
+import numpy as np
+np.random.seed(42)
+from ngclearn.components import PatchedSynapse
+from ngcsimlib.compilers import compile_command, wrap_command
+from numpy.testing import assert_array_equal
+
+from ngcsimlib.compilers.process import Process, transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.context import Context
+from ngcsimlib.utils.compartment import Get_Compartment_Batch
+
+
+def test_patchedSynapse():
+    np.random.seed(42)
+    name = "patched_synapse_ctx"
+    dkey = random.PRNGKey(42)
+    dkey, *subkeys = random.split(dkey, 100)
+    dt = 1.  # ms
+    # model hyper
+    shape = (10, 5)
+    n_sub_models = 2
+    stride_shape = (1, 1)
+    batch_size = 1
+    resist_scale = 1.0
+    with Context(name) as ctx:
+        a = PatchedSynapse(
+            name="a",
+            shape=shape,
+            n_sub_models=n_sub_models,
+            stride_shape=stride_shape,
+            resist_scale=resist_scale,
+            batch_size=batch_size,
+            weight_init={"dist": "gaussian", "std": 0.1},
+            bias_init={"dist": "constant", "value": 0.0}
+        )
+
+        advance_process = (Process() >> a.advance_state)
+        ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+        reset_process = (Process() >> a.reset)
+        ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+
+        # Compile and add commands
+        # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
+        # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
+        # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
+        # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
+
+        @Context.dynamicCommand
+        def clamp_inputs(x):
+            a.inputs.set(x)
+
+    inputs_seq = jnp.asarray(np.random.randn(1, 12))
+    weights = a.weights.value
+    biases = a.biases.value
+    expected_outputs = (jnp.matmul(inputs_seq, weights) * resist_scale) + biases
+    outputs_outs = []
+    ctx.reset()
+    ctx.clamp_inputs(inputs_seq)
+    ctx.run(t=0., dt=dt)
+    outputs_outs.append(a.outputs.value)
+    outputs_outs = jnp.concatenate(outputs_outs, axis=1)
+    # Verify outputs match expected values
+    np.testing.assert_allclose(outputs_outs, expected_outputs, atol=1e-5)
