@@ -291,9 +291,9 @@ def __init__(
291
291
:param torch.Tensor w: Strengths of synapses.
292
292
:param torch.Tensor b: Target population bias.
293
293
:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
294
- tensor of same shape/ size as (target.shape[0], source.shape[0], *kernel_size)
294
+ tensor of same size as w
295
295
:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
296
- tensor of same shape/ size as (target.shape[0], source.shape[0], *kernel_size)
296
+ tensor of same size as w
297
297
:param float norm: Total weight per target neuron normalization constant.
298
298
"""
299
299
super ().__init__ (source , target , nu , reduction , weight_decay , ** kwargs )
@@ -495,6 +495,7 @@ def reset_state_variables(self) -> None:
495
495
self .firing_rates = torch .zeros (self .source .s .shape )
496
496
497
497
498
+ # TODO: Add wmin/wmax tensor functionality to this one
498
499
class LocalConnection (AbstractConnection ):
499
500
# language=rst
500
501
"""
@@ -694,21 +695,23 @@ def __init__(
694
695
Keyword arguments:
695
696
:param LearningRule update_rule: Modifies connection parameters according to
696
697
some rule.
697
- :param torch.Tensor w: Strengths of synapses.
698
- :param float wmin: Minimum allowed value on the connection weights.
699
- :param float wmax: Maximum allowed value on the connection weights.
698
+ :param Union[float, torch.Tensor] w: Strengths of synapses. Can be single value or tensor of size ``target``
699
+ :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
700
+ tensor of same size as w
701
+ :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
702
+ tensor of same size as w
700
703
:param float norm: Total weight per target neuron normalization constant.
701
704
"""
702
705
super ().__init__ (source , target , nu , weight_decay , ** kwargs )
703
706
704
707
w = kwargs .get ("w" , None )
705
708
if w is None :
706
- if self .wmin == - np .inf or self .wmax == np .inf :
709
+ if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
707
710
w = torch .clamp ((torch .randn (1 )[0 ] + 1 ) / 10 , self .wmin , self .wmax )
708
711
else :
709
712
w = self .wmin + ((torch .randn (1 )[0 ] + 1 ) / 10 ) * (self .wmax - self .wmin )
710
713
else :
711
- if self .wmin != - np .inf or self .wmax != np .inf :
714
+ if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
712
715
w = torch .clamp (w , self .wmin , self .wmax )
713
716
714
717
self .w = Parameter (w , requires_grad = False )
0 commit comments