Skip to content

Commit e74a456

Browse files
committed
Added functionality for tensor wmin/wmax for MeanFieldConnection (plus a small clarification on its documentation)
1 parent f8e5768 commit e74a456

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

bindsnet/network/topology.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,9 @@ def __init__(
291291
:param torch.Tensor w: Strengths of synapses.
292292
:param torch.Tensor b: Target population bias.
293293
:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
294-
tensor of same shape/size as (target.shape[0], source.shape[0], *kernel_size)
294+
tensor of same size as w
295295
:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
296-
tensor of same shape/size as (target.shape[0], source.shape[0], *kernel_size)
296+
tensor of same size as w
297297
:param float norm: Total weight per target neuron normalization constant.
298298
"""
299299
super().__init__(source, target, nu, reduction, weight_decay, **kwargs)
@@ -495,6 +495,7 @@ def reset_state_variables(self) -> None:
495495
self.firing_rates = torch.zeros(self.source.s.shape)
496496

497497

498+
# TODO: Add wmin/wmax tensor functionality to this one
498499
class LocalConnection(AbstractConnection):
499500
# language=rst
500501
"""
@@ -694,21 +695,23 @@ def __init__(
694695
Keyword arguments:
695696
:param LearningRule update_rule: Modifies connection parameters according to
696697
some rule.
697-
:param torch.Tensor w: Strengths of synapses.
698-
:param float wmin: Minimum allowed value on the connection weights.
699-
:param float wmax: Maximum allowed value on the connection weights.
698+
:param Union[float, torch.Tensor] w: Strengths of synapses. Can be single value or tensor of size ``target``
699+
:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
700+
tensor of same size as w
701+
:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
702+
tensor of same size as w
700703
:param float norm: Total weight per target neuron normalization constant.
701704
"""
702705
super().__init__(source, target, nu, weight_decay, **kwargs)
703706

704707
w = kwargs.get("w", None)
705708
if w is None:
706-
if self.wmin == -np.inf or self.wmax == np.inf:
709+
if (self.wmin == -np.inf).all() or (self.wmax == np.inf).all():
707710
w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
708711
else:
709712
w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
710713
else:
711-
if self.wmin != -np.inf or self.wmax != np.inf:
714+
if (self.wmin == -np.inf).all() or (self.wmax == np.inf).all():
712715
w = torch.clamp(w, self.wmin, self.wmax)
713716

714717
self.w = Parameter(w, requires_grad=False)

0 commit comments

Comments
 (0)