Skip to content

Commit 4d0cdff

Browse files
authored
Refactor patched synapse (#110)
* Update patchedSynapse.py n_sub_models is now supporting default value =1 * Update hebbianPatchedSynapse.py n_sub_models is now supporting default value =1
1 parent 087d32a commit 4d0cdff

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
124124
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
125125
with number of inputs by number of outputs)
126126
127-
n_sub_models: The number of submodels in each layer
127+
n_sub_models: The number of submodels in each layer (Default: 1 similar functionality as DenseSynapse)
128128
129129
stride_shape: Stride shape of overlapping synaptic weight value matrix
130130
(Default: (0, 0))
@@ -185,7 +185,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
185185
batch_size: the size of each mini batch
186186
"""
187187

188-
def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
188+
def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
189189
w_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
190190
optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
191191
resist_scale=1., batch_size=1, **kwargs):

ngclearn/components/synapses/patched/patchedSynapse.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
6565
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
6666
with number of inputs by number of outputs)
6767
68-
n_sub_models: The number of submodels in each layer
68+
n_sub_models: The number of submodels in each layer (Default: 1 similar functionality as DenseSynapse)
6969
7070
stride_shape: Stride shape of overlapping synaptic weight value matrix
7171
(Default: (0, 0))
@@ -92,7 +92,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
9292
this to < 1. will result in a sparser synaptic structure
9393
"""
9494

95-
def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None,
95+
def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None,
9696
resist_scale=1., p_conn=1., batch_size=1, **kwargs):
9797
super().__init__(name, **kwargs)
9898

0 commit comments

Comments
 (0)