from brainpy._src.context import share
from brainpy._src.dyn import synapses
from brainpy._src.dyn.base import NeuDyn
+ from brainpy._src.dnn import linear
from brainpy._src.dynold.synouts import MgBlock, CUBA
from brainpy._src.initialize import Initializer, variable_
from brainpy._src.integrators.ode.generic import odeint
+ from brainpy._src.dyn.projections.aligns import _pre_delay_repr, _init_delay
from brainpy.types import ArrayType
- from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre, _DelayedSyn, _init_stp
+ from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre


__all__ = [
  'Delta',
@@ -100,12 +102,12 @@ def __init__(
      stop_spike_gradient: bool = False,
  ):
    super().__init__(name=name,
-                    pre=pre,
-                    post=post,
-                    conn=conn,
-                    output=output,
-                    stp=stp,
-                    mode=mode)
+                    pre=pre,
+                    post=post,
+                    conn=conn,
+                    output=output,
+                    stp=stp,
+                    mode=mode)

    # parameters
    self.stop_spike_gradient = stop_spike_gradient
@@ -298,29 +300,40 @@ def __init__(
                     mode=mode)
    # parameters
    self.stop_spike_gradient = stop_spike_gradient
-   self.comp_method = comp_method
-   self.tau = tau
-   if bm.size(self.tau) != 1:
-     raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')

-   # connections and weights
-   self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
+   # synapse dynamics
+   self.syn = synapses.Expon(post.varshape, tau=tau, method=method)
+
+   # Projection
+   if isinstance(conn, All2All):
+     self.comm = linear.AllToAll(pre.num, post.num, g_max)
+   elif isinstance(conn, One2One):
+     assert post.num == pre.num
+     self.comm = linear.OneToOne(pre.num, g_max)
+   else:
+     if comp_method == 'dense':
+       self.comm = linear.MaskedLinear(conn, g_max)
+     elif comp_method == 'sparse':
+       if self.stp is None:
+         self.comm = linear.EventCSRLinear(conn, g_max)
+       else:
+         self.comm = linear.CSRLinear(conn, g_max)
+     else:
+       raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')

    # variables
-   self.g = variable_(bm.zeros, self.post.num, self.mode)
-   self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
+   self.g = self.syn.g

-   # function
-   self.integral = odeint(lambda g, t: -g / self.tau, method=method)
+   # delay
+   self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

  def reset_state(self, batch_size=None):
-   self.g.value = variable_(bm.zeros, self.post.num, batch_size)
+   self.syn.reset_state(batch_size)
    self.output.reset_state(batch_size)
-   if self.stp is not None: self.stp.reset_state(batch_size)
+   if self.stp is not None:
+     self.stp.reset_state(batch_size)

  def update(self, pre_spike=None):
-   t, dt = share['t'], share['dt']
-
    # delays
    if pre_spike is None:
      pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
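For orientation, the `comp_method` branch introduced in the hunk above chooses between a dense masked mat-vec (`linear.MaskedLinear`) and event-driven sparse kernels (`linear.EventCSRLinear` / `linear.CSRLinear`). The NumPy sketch below only illustrates that distinction; the helper names and CSR layout are assumptions, not the BrainPy kernels:

import numpy as np

# Dense "masked linear": every presynaptic unit takes part in one mat-vec.
def masked_linear(spikes, weights, mask):
  return spikes.astype(float) @ (weights * mask)

# Event-driven "CSR linear": only rows of neurons that actually spiked are
# visited, using a CSR description (indptr/indices/data) of pre->post edges.
def event_csr_linear(spikes, indptr, indices, data, n_post):
  out = np.zeros(n_post)
  for i in np.flatnonzero(spikes):                     # loop over events only
    out[indices[indptr[i]:indptr[i + 1]]] += data[indptr[i]:indptr[i + 1]]
  return out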
@@ -332,52 +345,13 @@ def update(self, pre_spike=None):
      self.output.update()
      if self.stp is not None:
        self.stp.update(pre_spike)
+       pre_spike = self.stp(pre_spike)

    # post values
-   if isinstance(self.conn, All2All):
-     syn_value = bm.asarray(pre_spike, dtype=bm.float_)
-     if self.stp is not None: syn_value = self.stp(syn_value)
-     post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
-   elif isinstance(self.conn, One2One):
-     syn_value = bm.asarray(pre_spike, dtype=bm.float_)
-     if self.stp is not None: syn_value = self.stp(syn_value)
-     post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
-   else:
-     if self.comp_method == 'sparse':
-       f = lambda s: bm.event.csrmv(self.g_max,
-                                    self.conn_mask[0],
-                                    self.conn_mask[1],
-                                    s,
-                                    shape=(self.pre.num, self.post.num),
-                                    transpose=True)
-       if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
-       post_vs = f(pre_spike)
-       # if not isinstance(self.stp, _NullSynSTP):
-       #   raise NotImplementedError()
-     else:
-       syn_value = bm.asarray(pre_spike, dtype=bm.float_)
-       if self.stp is not None:
-         syn_value = self.stp(syn_value)
-       post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
-   # updates
-   self.g.value = self.integral(self.g.value, t, dt) + post_vs
+   g = self.syn(self.comm(pre_spike))

    # output
-   return self.output(self.g)
-
-
- class _DelayedDualExp(_DelayedSyn):
-   not_desc_params = ('master', 'mode')
-
-   def __init__(self, size, keep_size, mode, tau_decay, tau_rise, method, master, stp=None):
-     syn = synapses.DualExpon(size,
-                              keep_size,
-                              mode=mode,
-                              tau_decay=tau_decay,
-                              tau_rise=tau_rise,
-                              method=method)
-     stp = _init_stp(stp, master)
-     super().__init__(syn, stp)
+   return self.output(g)


class DualExponential(_TwoEndConnAlignPre):
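With this change, `Exponential.update` reduces to composing a connectivity operator with first-order synaptic dynamics, `g = self.syn(self.comm(pre_spike))`, followed by `self.output(g)`. A minimal self-contained sketch of that comm -> syn -> output pipeline, assuming a dense weight matrix and a CUBA-style output that simply returns the conductance (all names below are illustrative, not the BrainPy classes):

import numpy as np

def make_exponential_update(W, tau, dt=0.1):
  """Sketch of the comm -> syn -> output pipeline, not the library API."""
  g = np.zeros(W.shape[1])                     # postsynaptic conductance

  def update(pre_spike):
    nonlocal g
    inputs = pre_spike.astype(float) @ W       # comm: project spikes to post
    g = g * np.exp(-dt / tau) + inputs         # syn: exponential decay + drive
    return g                                   # output: CUBA-style passthrough

  return update

# usage: one step with a random spike vector
step = make_exponential_update(W=0.1 * np.ones((8, 4)), tau=5.0)
print(step(np.random.rand(8) < 0.2))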
@@ -507,14 +481,12 @@ def __init__(
      raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
                       f'But we got {self.tau_decay}')

-   syn = _DelayedDualExp.desc(pre.size,
-                              pre.keep_size,
-                              mode=mode,
-                              tau_decay=tau_decay,
-                              tau_rise=tau_rise,
-                              method=method,
-                              stp=stp,
-                              master=self)
+   syn = synapses.DualExpon(pre.size,
+                            pre.keep_size,
+                            mode=mode,
+                            tau_decay=tau_decay,
+                            tau_rise=tau_rise,
+                            method=method, )

    super().__init__(pre=pre,
                     post=post,
@@ -530,7 +502,6 @@ def __init__(
    self.check_post_attrs('input')
    # copy the references
-   syn = self.post.before_updates[self.proj._syn_id].syn.syn
    self.g = syn.g
    self.h = syn.h
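For reference, `synapses.DualExpon` implements the usual dual-exponential kernel: the conductance `g` relaxes with `tau_decay` while being driven by an auxiliary variable `h`, which relaxes with `tau_rise` and is kicked by presynaptic spikes. A hedged NumPy sketch of one Euler step, with parameter handling simplified relative to the library:

import numpy as np

def dual_expon_step(g, h, spikes, tau_decay=10.0, tau_rise=1.0, dt=0.1):
  # h is incremented by incoming spikes and decays with tau_rise;
  # g integrates h and decays with tau_decay.
  h = h + dt * (-h / tau_rise) + spikes.astype(float)
  g = g + dt * (-g / tau_decay + h)
  return g, h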
@@ -652,21 +623,6 @@ def __init__(
                     stop_spike_gradient=stop_spike_gradient)


- class _DelayedNMDA(_DelayedSyn):
-   not_desc_params = ('master', 'stp', 'mode')
-
-   def __init__(self, size, keep_size, mode, a, tau_decay, tau_rise, method, master, stp=None):
-     syn = synapses.NMDA(size,
-                         keep_size,
-                         mode=mode,
-                         a=a,
-                         tau_decay=tau_decay,
-                         tau_rise=tau_rise,
-                         method=method)
-     stp = _init_stp(stp, master)
-     super().__init__(syn, stp)
-
-
class NMDA(_TwoEndConnAlignPre):
  r"""NMDA synapse model.
@@ -825,15 +781,13 @@ def __init__(
    self.comp_method = comp_method
    self.stop_spike_gradient = stop_spike_gradient

-   syn = _DelayedNMDA.desc(pre.size,
-                           pre.keep_size,
-                           mode=mode,
-                           a=a,
-                           tau_decay=tau_decay,
-                           tau_rise=tau_rise,
-                           method=method,
-                           stp=stp,
-                           master=self)
+   syn = synapses.NMDA(pre.size,
+                       pre.keep_size,
+                       mode=mode,
+                       a=a,
+                       tau_decay=tau_decay,
+                       tau_rise=tau_rise,
+                       method=method, )

    super().__init__(pre=pre,
                     post=post,
@@ -848,7 +802,6 @@ def __init__(
                     mode=mode)

    # copy the references
-   syn = self.post.before_updates[self.proj._syn_id].syn.syn
    self.g = syn.g
    self.x = syn.x
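As context for the `MgBlock` output used with `NMDA`, the magnesium block is commonly modeled with the Jahr & Stevens (1990) voltage dependence, which scales the NMDA conductance by a sigmoidal function of the postsynaptic membrane potential. The exact parameterization in BrainPy may differ; the snippet below is only the textbook form:

import numpy as np

def mg_block(V, cc_Mg=1.2, alpha=0.062, beta=3.57):
  # Jahr & Stevens (1990): fraction of NMDA channels unblocked by Mg2+ at voltage V (mV).
  return 1.0 / (1.0 + np.exp(-alpha * V) * cc_Mg / beta)

# near rest (-65 mV) most channels are blocked; near 0 mV most are open
print(mg_block(np.array([-65.0, -40.0, 0.0])))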