Skip to content

Commit 2c75d8f

Browse files
committed
rewrite old synapses with decomposed components
1 parent 1ea4d15 commit 2c75d8f

File tree

8 files changed

+103
-356
lines changed

8 files changed

+103
-356
lines changed

brainpy/_src/dyn/projections/aligns.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -747,15 +747,15 @@ def __init__(
747747

748748
# synapse and delay initialization
749749
if _pre_delay_repr not in self.pre.after_updates:
750-
delay_cls = _init_delay(pre.return_info())
751-
self.pre.after_updates[_pre_delay_repr] = delay_cls
750+
delay_ins = _init_delay(pre.return_info())
751+
self.pre.after_updates[_pre_delay_repr] = delay_ins
752752

753753
# synapse
754754
self._syn_id = f'{str(delay)} / {syn.identifier}'
755755
if self._syn_id not in post.before_updates:
756756
# delay
757-
delay_cls: Delay = pre.after_updates[_pre_delay_repr]
758-
delay_access = DelayAccess(delay_cls, delay)
757+
delay_ins: Delay = pre.after_updates[_pre_delay_repr]
758+
delay_access = DelayAccess(delay_ins, delay)
759759
# synapse
760760
syn_cls = syn()
761761
# add to "after_updates"
@@ -765,8 +765,7 @@ def __init__(
765765
post.cur_inputs[self.name] = out
766766

767767
def update(self):
768-
x = self.post.before_updates[self._syn_id].syn.return_info()
769-
x = _get_return(x)
768+
x = _get_return(self.post.before_updates[self._syn_id].syn.return_info())
770769
current = self.comm(x)
771770
self.post.cur_inputs[self.name].bind_cond(current)
772771
return current

brainpy/_src/dynold/synapses/abstract_models.py

Lines changed: 50 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from brainpy._src.context import share
1010
from brainpy._src.dyn import synapses
1111
from brainpy._src.dyn.base import NeuDyn
12+
from brainpy._src.dnn import linear
1213
from brainpy._src.dynold.synouts import MgBlock, CUBA
1314
from brainpy._src.initialize import Initializer, variable_
1415
from brainpy._src.integrators.ode.generic import odeint
16+
from brainpy._src.dyn.projections.aligns import _pre_delay_repr, _init_delay
1517
from brainpy.types import ArrayType
16-
from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre, _DelayedSyn, _init_stp
18+
from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre
1719

1820
__all__ = [
1921
'Delta',
@@ -100,12 +102,12 @@ def __init__(
100102
stop_spike_gradient: bool = False,
101103
):
102104
super().__init__(name=name,
103-
pre=pre,
104-
post=post,
105-
conn=conn,
106-
output=output,
107-
stp=stp,
108-
mode=mode)
105+
pre=pre,
106+
post=post,
107+
conn=conn,
108+
output=output,
109+
stp=stp,
110+
mode=mode)
109111

110112
# parameters
111113
self.stop_spike_gradient = stop_spike_gradient
@@ -298,29 +300,40 @@ def __init__(
298300
mode=mode)
299301
# parameters
300302
self.stop_spike_gradient = stop_spike_gradient
301-
self.comp_method = comp_method
302-
self.tau = tau
303-
if bm.size(self.tau) != 1:
304-
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
305303

306-
# connections and weights
307-
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
304+
# synapse dynamics
305+
self.syn = synapses.Expon(post.varshape, tau=tau, method=method)
306+
307+
# Projection
308+
if isinstance(conn, All2All):
309+
self.comm = linear.AllToAll(pre.num, post.num, g_max)
310+
elif isinstance(conn, One2One):
311+
assert post.num == pre.num
312+
self.comm = linear.OneToOne(pre.num, g_max)
313+
else:
314+
if comp_method == 'dense':
315+
self.comm = linear.MaskedLinear(conn, g_max)
316+
elif comp_method == 'sparse':
317+
if self.stp is None:
318+
self.comm = linear.EventCSRLinear(conn, g_max)
319+
else:
320+
self.comm = linear.CSRLinear(conn, g_max)
321+
else:
322+
raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')
308323

309324
# variables
310-
self.g = variable_(bm.zeros, self.post.num, self.mode)
311-
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
325+
self.g = self.syn.g
312326

313-
# function
314-
self.integral = odeint(lambda g, t: -g / self.tau, method=method)
327+
# delay
328+
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
315329

316330
def reset_state(self, batch_size=None):
317-
self.g.value = variable_(bm.zeros, self.post.num, batch_size)
331+
self.syn.reset_state(batch_size)
318332
self.output.reset_state(batch_size)
319-
if self.stp is not None: self.stp.reset_state(batch_size)
333+
if self.stp is not None:
334+
self.stp.reset_state(batch_size)
320335

321336
def update(self, pre_spike=None):
322-
t, dt = share['t'], share['dt']
323-
324337
# delays
325338
if pre_spike is None:
326339
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
@@ -332,52 +345,13 @@ def update(self, pre_spike=None):
332345
self.output.update()
333346
if self.stp is not None:
334347
self.stp.update(pre_spike)
348+
pre_spike = self.stp(pre_spike)
335349

336350
# post values
337-
if isinstance(self.conn, All2All):
338-
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
339-
if self.stp is not None: syn_value = self.stp(syn_value)
340-
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
341-
elif isinstance(self.conn, One2One):
342-
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
343-
if self.stp is not None: syn_value = self.stp(syn_value)
344-
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
345-
else:
346-
if self.comp_method == 'sparse':
347-
f = lambda s: bm.event.csrmv(self.g_max,
348-
self.conn_mask[0],
349-
self.conn_mask[1],
350-
s,
351-
shape=(self.pre.num, self.post.num),
352-
transpose=True)
353-
if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
354-
post_vs = f(pre_spike)
355-
# if not isinstance(self.stp, _NullSynSTP):
356-
# raise NotImplementedError()
357-
else:
358-
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
359-
if self.stp is not None:
360-
syn_value = self.stp(syn_value)
361-
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
362-
# updates
363-
self.g.value = self.integral(self.g.value, t, dt) + post_vs
351+
g = self.syn(self.comm(pre_spike))
364352

365353
# output
366-
return self.output(self.g)
367-
368-
369-
class _DelayedDualExp(_DelayedSyn):
370-
not_desc_params = ('master', 'mode')
371-
372-
def __init__(self, size, keep_size, mode, tau_decay, tau_rise, method, master, stp=None):
373-
syn = synapses.DualExpon(size,
374-
keep_size,
375-
mode=mode,
376-
tau_decay=tau_decay,
377-
tau_rise=tau_rise,
378-
method=method)
379-
stp = _init_stp(stp, master)
380-
super().__init__(syn, stp)
354+
return self.output(g)
381355

382356

383357
class DualExponential(_TwoEndConnAlignPre):
@@ -507,14 +481,12 @@ def __init__(
507481
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
508482
f'But we got {self.tau_decay}')
509483

510-
syn = _DelayedDualExp.desc(pre.size,
511-
pre.keep_size,
512-
mode=mode,
513-
tau_decay=tau_decay,
514-
tau_rise=tau_rise,
515-
method=method,
516-
stp=stp,
517-
master=self)
484+
syn = synapses.DualExpon(pre.size,
485+
pre.keep_size,
486+
mode=mode,
487+
tau_decay=tau_decay,
488+
tau_rise=tau_rise,
489+
method=method, )
518490

519491
super().__init__(pre=pre,
520492
post=post,
@@ -530,7 +502,6 @@ def __init__(
530502

531503
self.check_post_attrs('input')
532504
# copy the references
533-
syn = self.post.before_updates[self.proj._syn_id].syn.syn
534505
self.g = syn.g
535506
self.h = syn.h
536507

@@ -652,21 +623,6 @@ def __init__(
652623
stop_spike_gradient=stop_spike_gradient)
653624

654625

655-
class _DelayedNMDA(_DelayedSyn):
656-
not_desc_params = ('master', 'stp', 'mode')
657-
658-
def __init__(self, size, keep_size, mode, a, tau_decay, tau_rise, method, master, stp=None):
659-
syn = synapses.NMDA(size,
660-
keep_size,
661-
mode=mode,
662-
a=a,
663-
tau_decay=tau_decay,
664-
tau_rise=tau_rise,
665-
method=method)
666-
stp = _init_stp(stp, master)
667-
super().__init__(syn, stp)
668-
669-
670626
class NMDA(_TwoEndConnAlignPre):
671627
r"""NMDA synapse model.
672628
@@ -825,15 +781,13 @@ def __init__(
825781
self.comp_method = comp_method
826782
self.stop_spike_gradient = stop_spike_gradient
827783

828-
syn = _DelayedNMDA.desc(pre.size,
829-
pre.keep_size,
830-
mode=mode,
831-
a=a,
832-
tau_decay=tau_decay,
833-
tau_rise=tau_rise,
834-
method=method,
835-
stp=stp,
836-
master=self)
784+
syn = synapses.NMDA(pre.size,
785+
pre.keep_size,
786+
mode=mode,
787+
a=a,
788+
tau_decay=tau_decay,
789+
tau_rise=tau_rise,
790+
method=method, )
837791

838792
super().__init__(pre=pre,
839793
post=post,
@@ -848,7 +802,6 @@ def __init__(
848802
mode=mode)
849803

850804
# copy the references
851-
syn = self.post.before_updates[self.proj._syn_id].syn.syn
852805
self.g = syn.g
853806
self.x = syn.x
854807

0 commit comments

Comments
 (0)