|
| 1 | +from jax import random, numpy as jnp, jit |
| 2 | +from ngcsimlib.compilers.process import transition |
| 3 | +from ngcsimlib.component import Component |
| 4 | +from ngcsimlib.compartment import Compartment |
| 5 | + |
| 6 | +from ngclearn.utils.weight_distribution import initialize_params |
| 7 | +from ngcsimlib.logger import info |
| 8 | +from ngclearn.components.synapses import DenseSynapse |
| 9 | +from ngclearn.utils import tensorstats |
| 10 | + |
| 11 | +class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable |
| 12 | + """ |
| 13 | + A dynamic double-exponential synaptic cable; this synapse evolves according to difference of two exponentials |
| 14 | + synaptic conductance dynamics. |
| 15 | + Specifically, the conductance dynamics are as follows: |
| 16 | +
|
| 17 | + | dh/dt = -h/tau_rise + gBar sum_k (t - t_k) * (1/tau_rise - 1/tau_decay) // h is an intermediate variable |
| 18 | + | dg/dt = -g/tau_decay + h |
| 19 | + | i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation |
| 20 | + | where: syn_rest is the post-synaptic reverse potential for this synapse |
| 21 | + | t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit |
| 22 | +
|
| 23 | + | --- Synapse Compartments: --- |
| 24 | + | inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes) |
| 25 | + | outputs - output signals (also equal to i_syn, total electrical current) |
| 26 | + | v - coupled voltages from post-synaptic neurons this synaptic cable connects to |
| 27 | + | weights - current value matrix of synaptic efficacies |
| 28 | + | biases - current value vector of synaptic bias values |
| 29 | + | --- Dynamic / Short-term Plasticity Compartments: --- |
| 30 | + | g_syn - fixed value matrix of synaptic resources (U) |
| 31 | + | i_syn - derived total electrical current variable |
| 32 | +
|
| 33 | + Args: |
| 34 | + name: the string name of this synapse |
| 35 | +
|
| 36 | + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple |
| 37 | + with number of inputs by number of outputs) |
| 38 | +
|
| 39 | + tau_decay: synaptic decay time constant (ms) |
| 40 | +
|
| 41 | + tau_rise: synaptic increase/rise time constant (ms) |
| 42 | +
|
| 43 | + g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight") |
| 44 | +
|
| 45 | + syn_rest: synaptic reversal potential; note, if this is set to `None`, then this |
| 46 | + synaptic conductance model will no longer be voltage-dependent (and will ignore |
| 47 | + the voltage compartment provided by an external spiking cell) |
| 48 | +
|
| 49 | + weight_init: a kernel to drive initialization of this synaptic cable's values; |
| 50 | + typically a tuple with 1st element as a string calling the name of |
| 51 | + initialization to use |
| 52 | +
|
| 53 | + bias_init: a kernel to drive initialization of biases for this synaptic cable |
| 54 | + (Default: None, which turns off/disables biases) <unused> |
| 55 | +
|
| 56 | + resist_scale: a fixed (resistance) scaling factor to apply to synaptic |
| 57 | + transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) |
| 58 | +
|
| 59 | + p_conn: probability of a connection existing (default: 1.); setting |
| 60 | + this to < 1 and > 0. will result in a sparser synaptic structure |
| 61 | + (lower values yield sparse structure) |
| 62 | +
|
| 63 | + is_nonplastic: boolean indicating if this synapse permits plasticity adjustments (Default: True) |
| 64 | +
|
| 65 | + """ |
| 66 | + |
| 67 | + # Define Functions |
| 68 | + def __init__( |
| 69 | + self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., |
| 70 | + is_nonplastic=True, **kwargs |
| 71 | + ): |
| 72 | + super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) |
| 73 | + ## dynamic synapse meta-parameters |
| 74 | + self.tau_decay = tau_decay |
| 75 | + self.tau_rise = tau_rise |
| 76 | + self.g_syn_bar = g_syn_bar |
| 77 | + self.syn_rest = syn_rest ## synaptic resting potential |
| 78 | + |
| 79 | + ## Set up short-term plasticity / dynamic synapse compartment values |
| 80 | + #tmp_key, *subkeys = random.split(self.key.value, 4) |
| 81 | + #preVals = jnp.zeros((self.batch_size, shape[0])) |
| 82 | + postVals = jnp.zeros((self.batch_size, shape[1])) |
| 83 | + self.v = Compartment(postVals) ## coupled voltage (from a post-synaptic neuron) |
| 84 | + self.i_syn = Compartment(postVals) ## electrical current output |
| 85 | + self.g_syn = Compartment(postVals) ## conductance variable |
| 86 | + self.h_syn = Compartment(postVals) ## intermediate conductance variable |
| 87 | + if is_nonplastic: |
| 88 | + self.weights.set(self.weights.value * 0 + 1.) |
| 89 | + |
| 90 | + @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"]) |
| 91 | + @staticmethod |
| 92 | + def advance_state( |
| 93 | + dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v |
| 94 | + ): |
| 95 | + s = inputs |
| 96 | + #A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay)) |
| 97 | + A = 1. |
| 98 | + ## advance conductance variable(s) |
| 99 | + _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) |
| 100 | + dhsyn_dt = -h_syn/tau_rise + ((_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * A) * (1./dt) |
| 101 | + h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h |
| 102 | + |
| 103 | + dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) |
| 104 | + g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g |
| 105 | + |
| 106 | + ## compute derive electrical current variable |
| 107 | + i_syn = -g_syn * Rscale |
| 108 | + if syn_rest is not None: |
| 109 | + i_syn = -(g_syn * Rscale) * (v - syn_rest) |
| 110 | + outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases |
| 111 | + return outputs, i_syn, g_syn, h_syn |
| 112 | + |
| 113 | + @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"]) |
| 114 | + @staticmethod |
| 115 | + def reset(batch_size, shape): |
| 116 | + preVals = jnp.zeros((batch_size, shape[0])) |
| 117 | + postVals = jnp.zeros((batch_size, shape[1])) |
| 118 | + inputs = preVals |
| 119 | + outputs = postVals |
| 120 | + i_syn = postVals |
| 121 | + g_syn = postVals |
| 122 | + h_syn = postVals |
| 123 | + v = postVals |
| 124 | + return inputs, outputs, i_syn, g_syn, h_syn, v |
| 125 | + |
| 126 | + def save(self, directory, **kwargs): |
| 127 | + file_name = directory + "/" + self.name + ".npz" |
| 128 | + if self.bias_init != None: |
| 129 | + jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) |
| 130 | + else: |
| 131 | + jnp.savez(file_name, weights=self.weights.value) |
| 132 | + |
| 133 | + def load(self, directory, **kwargs): |
| 134 | + file_name = directory + "/" + self.name + ".npz" |
| 135 | + data = jnp.load(file_name) |
| 136 | + self.weights.set(data['weights']) |
| 137 | + if "biases" in data.keys(): |
| 138 | + self.biases.set(data['biases']) |
| 139 | + |
| 140 | + @classmethod |
| 141 | + def help(cls): ## component help function |
| 142 | + properties = { |
| 143 | + "synapse_type": "DoubleExpSynapse - performs a synaptic transformation of inputs to produce " |
| 144 | + "output signals (e.g., a scaled linear multivariate transformation); " |
| 145 | + "this synapse is dynamic, changing according to a difference of exponentials kernel" |
| 146 | + } |
| 147 | + compartment_props = { |
| 148 | + "inputs": |
| 149 | + {"inputs": "Takes in external input signal values", |
| 150 | + "v" : "Post-synaptic voltage dependence (comes from a wired-to spiking cell)"}, |
| 151 | + "states": |
| 152 | + {"weights": "Synapse efficacy/strength parameter values", |
| 153 | + "biases": "Base-rate/bias parameter values", |
| 154 | + "g_syn" : "Synaptic conductnace", |
| 155 | + "h_syn" : "Intermediate synaptic conductance", |
| 156 | + "i_syn" : "Total electrical current", |
| 157 | + "key": "JAX PRNG key"}, |
| 158 | + "outputs": |
| 159 | + {"outputs": "Output of synaptic transformation"}, |
| 160 | + } |
| 161 | + hyperparams = { |
| 162 | + "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", |
| 163 | + "weight_init": "Initialization conditions for synaptic weight (W) values", |
| 164 | + "bias_init": "Initialization conditions for bias/base-rate (b) values", |
| 165 | + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", |
| 166 | + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", |
| 167 | + "tau_decay": "Conductance decay time constant (ms)", |
| 168 | + "tau_rise": "Conductance rise/increase time constant (ms)", |
| 169 | + "g_bar_syn": "Maximum conductance value", |
| 170 | + "syn_rest": "Synaptic reversal potential" |
| 171 | + } |
| 172 | + info = {cls.__name__: properties, |
| 173 | + "compartments": compartment_props, |
| 174 | + "dynamics": "outputs = g_syn * (v - syn_rest); " |
| 175 | + "dhsyn_dt = (1/tau_rise - 1/tau_decay) * (W * inputs) * g_syn_bar - h_syn/tau_rise " |
| 176 | + "dgsyn_dt = -g_syn/tau_decay + h_syn", |
| 177 | + "hyperparameters": hyperparams} |
| 178 | + return info |
| 179 | + |
| 180 | + def __repr__(self): |
| 181 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 182 | + maxlen = max(len(c) for c in comps) + 5 |
| 183 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 184 | + for c in comps: |
| 185 | + stats = tensorstats(getattr(self, c).value) |
| 186 | + if stats is not None: |
| 187 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 188 | + line = ", ".join(line) |
| 189 | + else: |
| 190 | + line = "None" |
| 191 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 192 | + return lines |
0 commit comments