Commit

Trialing replacement of all() with any(), plus beginning adding functionality for tensor wmin/wmax in SparseConnection
C-Earl committed Jul 17, 2021
1 parent e74a456 commit 78596e3
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions bindsnet/network/topology.py
@@ -161,12 +161,12 @@ def __init__(
         w = kwargs.get("w", None)
         inf = torch.tensor(np.inf)
         if w is None:
-            if (self.wmin == -inf).all() or (self.wmax == inf).all():
+            if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
             else:
                 w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
         else:
-            if (self.wmin != -inf).all() or (self.wmax != inf).all():
+            if (self.wmin != -inf).any() or (self.wmax != inf).any():
                 w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)

         self.w = Parameter(w, requires_grad=False)
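Why any() rather than all(): once wmin/wmax may be per-synapse tensors, a bounds tensor that mixes finite and infinite entries satisfies neither .all() test, so initialization falls into the scaling branch, where arithmetic with inf yields nan. With .any(), a single unbounded entry routes initialization through torch.clamp, which tolerates infinite bounds. A minimal sketch of the difference (illustrative values, not from the commit; assumes a PyTorch version whose torch.clamp accepts tensor bounds):

import numpy as np
import torch

inf = torch.tensor(np.inf)
wmin = torch.tensor([-np.inf, 0.0])  # one unbounded entry, one finite entry
wmax = torch.tensor([1.0, 0.5])

(wmin == -inf).all()  # tensor(False): mixed bounds miss the clamp branch
(wmin == -inf).any()  # tensor(True): any unbounded entry selects it

torch.clamp(torch.rand(2), wmin, wmax)  # fine: infinite bounds are no-ops
wmin + torch.rand(2) * (wmax - wmin)    # nan in the unbounded slot: -inf + inf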
@@ -753,6 +753,7 @@ def reset_state_variables(self) -> None:
         super().reset_state_variables()


+# TODO: Potential tensor functionality for 'sparsity' kwarg
 class SparseConnection(AbstractConnection):
     # language=rst
     """
@@ -781,7 +782,7 @@ def __init__(
     Keyword arguments:
-    :param torch.Tensor w: Strengths of synapses.
+    :param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format
     :param float sparsity: Fraction of sparse connections to use.
     :param LearningRule update_rule: Modifies connection parameters according to
         some rule.
@@ -805,16 +806,15 @@ def __init__(
             i = torch.bernoulli(
                 1 - self.sparsity * torch.ones(*source.shape, *target.shape)
             )
-            if self.wmin == -np.inf or self.wmax == np.inf:
+            if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                 v = torch.clamp(
-                    torch.rand(*source.shape, *target.shape)[i.bool()],
+                    torch.rand(*source.shape, *target.shape),
                     self.wmin,
                     self.wmax,
-                )
+                )[i.bool()]
             else:
-                v = self.wmin + torch.rand(*source.shape, *target.shape)[i.bool()] * (
-                    self.wmax - self.wmin
-                )
+                v = (self.wmin + torch.rand(*source.shape, *target.shape)
+                    * (self.wmax - self.wmin))[i.bool()]
             w = torch.sparse.FloatTensor(i.nonzero().t(), v)
         elif w is not None and self.sparsity is None:
             assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)"
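Moving [i.bool()] outside the clamp and scaling expressions is the substance of the SparseConnection change: with scalar bounds, masking before clamping was harmless, but per-synapse tensor bounds must stay aligned with the full dense weight tensor, so the whole tensor is clamped (or scaled) first and the surviving entries selected afterwards. A sketch of the resulting construction (the sizes and bounds here are illustrative assumptions, not from the commit):

import torch

n_src, n_tgt, sparsity = 4, 5, 0.7

# Each entry survives with probability (1 - sparsity).
i = torch.bernoulli(1 - sparsity * torch.ones(n_src, n_tgt))

# Per-synapse bounds, same shape as the dense weights (the tensor
# wmin/wmax case this commit begins to support).
wmin = torch.zeros(n_src, n_tgt)
wmax = torch.ones(n_src, n_tgt)

# Clamp the full dense tensor, then mask; masking first would flatten
# the tensor and misalign it with the element-wise bounds.
v = torch.clamp(torch.rand(n_src, n_tgt), wmin, wmax)[i.bool()]

# Sparse COO weights from the mask's coordinates and surviving values.
w = torch.sparse.FloatTensor(i.nonzero().t(), v, torch.Size([n_src, n_tgt]))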
@@ -832,7 +832,8 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
         :return: Incoming spikes multiplied by synaptic weights (with or without
             decaying spike activation).
         """
-        return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)
+        return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1)
+        # return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)

     def update(self, **kwargs) -> None:
         # language=rst
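The compute change is easiest to read through shapes. A plausible motivation (an assumption; the commit message does not say): the spike tensor s arrives with a leading batch dimension, shape [1, n], so s.unsqueeze(-1) produces a 3-D [1, n, 1] tensor that torch.mm, which is strictly 2-D, rejects, while the view reshapes it to [n, 1]. A sketch with stand-in identity weights:

import torch

n = 5
w = torch.eye(n).to_sparse()   # stand-in sparse weights, shape [n, n]
s = torch.rand(1, n).round()   # spikes with a leading batch dim of 1

# [1, n] -> [n, 1] so the 2-D torch.mm applies; squeeze back to [n].
out = torch.mm(w, s.view(s.shape[1], 1).float()).squeeze(-1)
print(out.shape)  # torch.Size([5])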
