Skip to content

Commit bed5331

Browse files
author
Alexander Ororbia
committed
integrated double-exp syn model
1 parent 26f0982 commit bed5331

File tree

8 files changed

+290
-57
lines changed

8 files changed

+290
-57
lines changed
31.1 KB
Loading

docs/source/ngclearn.components.synapses.rst

+8
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ ngclearn.components.synapses.denseSynapse module
3939
:undoc-members:
4040
:show-inheritance:
4141

42+
ngclearn.components.synapses.doubleExpSynapse module
43+
----------------------------------------------------
44+
45+
.. automodule:: ngclearn.components.synapses.doubleExpSynapse
46+
:members:
47+
:undoc-members:
48+
:show-inheritance:
49+
4250
ngclearn.components.synapses.exponentialSynapse module
4351
------------------------------------------------------
4452

docs/tutorials/neurocog/dynamic_synapses.md

+69-38
Large diffs are not rendered by default.

ngclearn/components/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .synapses.hebbian.BCMSynapse import BCMSynapse
4040
from .synapses.STPDenseSynapse import STPDenseSynapse
4141
from .synapses.exponentialSynapse import ExponentialSynapse
42+
from .synapses.doubleExpSynapse import DoupleExpSynapse
4243
from .synapses.alphaSynapse import AlphaSynapse
4344

4445
## point to convolutional component types

ngclearn/components/synapses/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
## short-term plasticity components
66
from .STPDenseSynapse import STPDenseSynapse
77
from .exponentialSynapse import ExponentialSynapse
8+
from .doubleExpSynapse import DoupleExpSynapse
89
from .alphaSynapse import AlphaSynapse
910

1011
## dense synaptic components

ngclearn/components/synapses/alphaSynapse.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
1313
A dynamic alpha synaptic cable; this synapse evolves according to alpha synaptic conductance dynamics.
1414
Specifically, the conductance dynamics are as follows:
1515
16-
| dh/dt = -h/tau_syn + gBar sum_k (t - t_k) // h is an intermediate variable
17-
| dg/dt = -g/tau_syn + h/tau_syn
16+
| dh/dt = -h/tau_decay + gBar sum_k (t - t_k) // h is an intermediate variable
17+
| dg/dt = -g/tau_decay + h/tau_decay
1818
| i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation
1919
| where: syn_rest is the post-synaptic reverse potential for this synapse
2020
| t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit
@@ -36,7 +36,7 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
3636
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
3737
with number of inputs by number of outputs)
3838
39-
tau_syn: synaptic time constant (ms)
39+
tau_decay: synaptic decay time constant (ms)
4040
4141
g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight")
4242
@@ -64,12 +64,12 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
6464

6565
# Define Functions
6666
def __init__(
67-
self, name, shape, tau_syn, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
67+
self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
6868
is_nonplastic=True, **kwargs
6969
):
7070
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
7171
## dynamic synapse meta-parameters
72-
self.tau_syn = tau_syn
72+
self.tau_decay = tau_decay
7373
self.g_syn_bar = g_syn_bar
7474
self.syn_rest = syn_rest ## synaptic resting potential
7575

@@ -87,15 +87,15 @@ def __init__(
8787
@transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
8888
@staticmethod
8989
def advance_state(
90-
dt, tau_syn, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
90+
dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
9191
):
9292
s = inputs
9393
## advance conductance variable(s)
9494
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
95-
dhsyn_dt = -h_syn/tau_syn + _out * g_syn_bar
95+
dhsyn_dt = -h_syn/tau_decay + (_out * g_syn_bar) * (1./dt)
9696
h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
9797

98-
dgsyn_dt = -g_syn/tau_syn + h_syn # or -g_syn/tau_syn + h_syn/tau_syn
98+
dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay
9999
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
100100

101101
## compute derive electrical current variable
@@ -159,15 +159,15 @@ def help(cls): ## component help function
159159
"bias_init": "Initialization conditions for bias/base-rate (b) values",
160160
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
161161
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
162-
"tau_syn": "Synaptic time constant (ms)",
162+
"tau_decay": "Conductance decay time constant (ms)",
163163
"g_bar_syn": "Maximum conductance value",
164164
"syn_rest": "Synaptic reversal potential"
165165
}
166166
info = {cls.__name__: properties,
167167
"compartments": compartment_props,
168168
"dynamics": "outputs = g_syn * (v - syn_rest); "
169-
"dhsyn_dt = (W * inputs) * g_syn_bar - h_syn/tau_syn "
170-
"dgsyn_dt = -g_syn/tau_syn + h_syn",
169+
"dhsyn_dt = (W * inputs) * g_syn_bar - h_syn/tau_decay "
170+
"dgsyn_dt = -g_syn/tau_decay + h_syn",
171171
"hyperparameters": hyperparams}
172172
return info
173173

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from jax import random, numpy as jnp, jit
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.components.synapses import DenseSynapse
9+
from ngclearn.utils import tensorstats
10+
11+
class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable
12+
"""
13+
A dynamic double-exponential synaptic cable; this synapse evolves according to difference of two exponentials
14+
synaptic conductance dynamics.
15+
Specifically, the conductance dynamics are as follows:
16+
17+
| dh/dt = -h/tau_rise + gBar sum_k (t - t_k) * (1/tau_rise - 1/tau_decay) // h is an intermediate variable
18+
| dg/dt = -g/tau_decay + h
19+
| i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation
20+
| where: syn_rest is the post-synaptic reverse potential for this synapse
21+
| t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit
22+
23+
| --- Synapse Compartments: ---
24+
| inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes)
25+
| outputs - output signals (also equal to i_syn, total electrical current)
26+
| v - coupled voltages from post-synaptic neurons this synaptic cable connects to
27+
| weights - current value matrix of synaptic efficacies
28+
| biases - current value vector of synaptic bias values
29+
| --- Dynamic / Short-term Plasticity Compartments: ---
30+
| g_syn - fixed value matrix of synaptic resources (U)
31+
| i_syn - derived total electrical current variable
32+
33+
Args:
34+
name: the string name of this synapse
35+
36+
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
37+
with number of inputs by number of outputs)
38+
39+
tau_decay: synaptic decay time constant (ms)
40+
41+
tau_rise: synaptic increase/rise time constant (ms)
42+
43+
g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight")
44+
45+
syn_rest: synaptic reversal potential; note, if this is set to `None`, then this
46+
synaptic conductance model will no longer be voltage-dependent (and will ignore
47+
the voltage compartment provided by an external spiking cell)
48+
49+
weight_init: a kernel to drive initialization of this synaptic cable's values;
50+
typically a tuple with 1st element as a string calling the name of
51+
initialization to use
52+
53+
bias_init: a kernel to drive initialization of biases for this synaptic cable
54+
(Default: None, which turns off/disables biases) <unused>
55+
56+
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
57+
transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
58+
59+
p_conn: probability of a connection existing (default: 1.); setting
60+
this to < 1 and > 0. will result in a sparser synaptic structure
61+
(lower values yield sparse structure)
62+
63+
is_nonplastic: boolean indicating if this synapse permits plasticity adjustments (Default: True)
64+
65+
"""
66+
67+
# Define Functions
68+
def __init__(
69+
self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
70+
is_nonplastic=True, **kwargs
71+
):
72+
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
73+
## dynamic synapse meta-parameters
74+
self.tau_decay = tau_decay
75+
self.tau_rise = tau_rise
76+
self.g_syn_bar = g_syn_bar
77+
self.syn_rest = syn_rest ## synaptic resting potential
78+
79+
## Set up short-term plasticity / dynamic synapse compartment values
80+
#tmp_key, *subkeys = random.split(self.key.value, 4)
81+
#preVals = jnp.zeros((self.batch_size, shape[0]))
82+
postVals = jnp.zeros((self.batch_size, shape[1]))
83+
self.v = Compartment(postVals) ## coupled voltage (from a post-synaptic neuron)
84+
self.i_syn = Compartment(postVals) ## electrical current output
85+
self.g_syn = Compartment(postVals) ## conductance variable
86+
self.h_syn = Compartment(postVals) ## intermediate conductance variable
87+
if is_nonplastic:
88+
self.weights.set(self.weights.value * 0 + 1.)
89+
90+
@transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
91+
@staticmethod
92+
def advance_state(
93+
dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
94+
):
95+
s = inputs
96+
#A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay))
97+
A = 1.
98+
## advance conductance variable(s)
99+
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
100+
dhsyn_dt = -h_syn/tau_rise + ((_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * A) * (1./dt)
101+
h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
102+
103+
dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt)
104+
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
105+
106+
## compute derive electrical current variable
107+
i_syn = -g_syn * Rscale
108+
if syn_rest is not None:
109+
i_syn = -(g_syn * Rscale) * (v - syn_rest)
110+
outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
111+
return outputs, i_syn, g_syn, h_syn
112+
113+
@transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"])
114+
@staticmethod
115+
def reset(batch_size, shape):
116+
preVals = jnp.zeros((batch_size, shape[0]))
117+
postVals = jnp.zeros((batch_size, shape[1]))
118+
inputs = preVals
119+
outputs = postVals
120+
i_syn = postVals
121+
g_syn = postVals
122+
h_syn = postVals
123+
v = postVals
124+
return inputs, outputs, i_syn, g_syn, h_syn, v
125+
126+
def save(self, directory, **kwargs):
127+
file_name = directory + "/" + self.name + ".npz"
128+
if self.bias_init != None:
129+
jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
130+
else:
131+
jnp.savez(file_name, weights=self.weights.value)
132+
133+
def load(self, directory, **kwargs):
134+
file_name = directory + "/" + self.name + ".npz"
135+
data = jnp.load(file_name)
136+
self.weights.set(data['weights'])
137+
if "biases" in data.keys():
138+
self.biases.set(data['biases'])
139+
140+
@classmethod
141+
def help(cls): ## component help function
142+
properties = {
143+
"synapse_type": "DoubleExpSynapse - performs a synaptic transformation of inputs to produce "
144+
"output signals (e.g., a scaled linear multivariate transformation); "
145+
"this synapse is dynamic, changing according to a difference of exponentials kernel"
146+
}
147+
compartment_props = {
148+
"inputs":
149+
{"inputs": "Takes in external input signal values",
150+
"v" : "Post-synaptic voltage dependence (comes from a wired-to spiking cell)"},
151+
"states":
152+
{"weights": "Synapse efficacy/strength parameter values",
153+
"biases": "Base-rate/bias parameter values",
154+
"g_syn" : "Synaptic conductnace",
155+
"h_syn" : "Intermediate synaptic conductance",
156+
"i_syn" : "Total electrical current",
157+
"key": "JAX PRNG key"},
158+
"outputs":
159+
{"outputs": "Output of synaptic transformation"},
160+
}
161+
hyperparams = {
162+
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
163+
"weight_init": "Initialization conditions for synaptic weight (W) values",
164+
"bias_init": "Initialization conditions for bias/base-rate (b) values",
165+
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
166+
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
167+
"tau_decay": "Conductance decay time constant (ms)",
168+
"tau_rise": "Conductance rise/increase time constant (ms)",
169+
"g_bar_syn": "Maximum conductance value",
170+
"syn_rest": "Synaptic reversal potential"
171+
}
172+
info = {cls.__name__: properties,
173+
"compartments": compartment_props,
174+
"dynamics": "outputs = g_syn * (v - syn_rest); "
175+
"dhsyn_dt = (1/tau_rise - 1/tau_decay) * (W * inputs) * g_syn_bar - h_syn/tau_rise "
176+
"dgsyn_dt = -g_syn/tau_decay + h_syn",
177+
"hyperparameters": hyperparams}
178+
return info
179+
180+
def __repr__(self):
181+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
182+
maxlen = max(len(c) for c in comps) + 5
183+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
184+
for c in comps:
185+
stats = tensorstats(getattr(self, c).value)
186+
if stats is not None:
187+
line = [f"{k}: {v}" for k, v in stats.items()]
188+
line = ", ".join(line)
189+
else:
190+
line = "None"
191+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
192+
return lines

ngclearn/components/synapses/exponentialSynapse.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
1313
A dynamic exponential synaptic cable; this synapse evolves according to exponential synaptic conductance dynamics.
1414
Specifically, the conductance dynamics are as follows:
1515
16-
| dg/dt = -g/tau_syn + gBar sum_k (t - t_k)
16+
| dg/dt = -g/tau_decay + gBar sum_k (t - t_k)
1717
| i_syn = g * (syn_rest - v) // g is `g_syn` in this synapse implementation
1818
| where: syn_rest is the post-synaptic reverse potential for this synapse
1919
| t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit
@@ -35,7 +35,7 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
3535
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
3636
with number of inputs by number of outputs)
3737
38-
tau_syn: synaptic time constant (ms)
38+
tau_decay: synaptic decay time constant (ms)
3939
4040
g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight")
4141
@@ -63,12 +63,12 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
6363

6464
# Define Functions
6565
def __init__(
66-
self, name, shape, tau_syn, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
66+
self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
6767
is_nonplastic=True, **kwargs
6868
):
6969
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
7070
## dynamic synapse meta-parameters
71-
self.tau_syn = tau_syn
71+
self.tau_decay = tau_decay
7272
self.g_syn_bar = g_syn_bar
7373
self.syn_rest = syn_rest ## synaptic resting potential
7474

@@ -85,12 +85,12 @@ def __init__(
8585
@transition(output_compartments=["outputs", "i_syn", "g_syn"])
8686
@staticmethod
8787
def advance_state(
88-
dt, tau_syn, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v
88+
dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v
8989
):
9090
s = inputs
9191
## advance conductance variable
9292
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
93-
dgsyn_dt = -g_syn/tau_syn + (_out * g_syn_bar) * (1./dt)
93+
dgsyn_dt = -g_syn/tau_decay + (_out * g_syn_bar) * (1./dt)
9494
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance
9595
## compute derive electrical current variable
9696
i_syn = -g_syn * Rscale
@@ -152,14 +152,14 @@ def help(cls): ## component help function
152152
"bias_init": "Initialization conditions for bias/base-rate (b) values",
153153
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
154154
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
155-
"tau_syn": "Synaptic time constant (ms)",
155+
"tau_decay": "Conductance decay time constant (ms)",
156156
"g_bar_syn": "Maximum conductance value",
157157
"syn_rest": "Synaptic reversal potential"
158158
}
159159
info = {cls.__name__: properties,
160160
"compartments": compartment_props,
161161
"dynamics": "outputs = g_syn * (v - syn_rest); "
162-
"dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_syn ",
162+
"dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_decay ",
163163
"hyperparameters": hyperparams}
164164
return info
165165

0 commit comments

Comments
 (0)