Lowering precision for old connections
n-shevko committed Mar 8, 2025
1 parent 01bc1b0 commit e0f9bf2
Showing 1 changed file with 44 additions and 2 deletions.
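In short, this commit threads a new w_dtype constructor argument through the connection classes in bindsnet/network/topology.py, so synaptic weights can be stored at reduced precision (the default remains torch.float32). A minimal usage sketch follows; it is illustrative only and not part of the diff, and it assumes BindsNET's Input, LIFNodes, and Connection classes together with this commit:

    import torch

    from bindsnet.network.nodes import Input, LIFNodes
    from bindsnet.network.topology import Connection

    # Build a small source/target pair and keep the synapse weights in half precision.
    source = Input(n=100)
    target = LIFNodes(n=50)
    conn = Connection(source=source, target=target, w_dtype=torch.float16)

    print(conn.w.dtype)  # torch.float16

Randomly initialized weights are generated in float32 and then cast to w_dtype; explicitly supplied weights go through the cast_dtype_if_needed helper added in the first hunk below.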
46 changes: 44 additions & 2 deletions bindsnet/network/topology.py
@@ -141,6 +141,14 @@ def reset_state_variables(self) -> None:
Contains resetting logic for the connection.
"""

def cast_dtype_if_needed(self, w, w_dtype):
# Cast a user-provided weight tensor to w_dtype, warning on mismatch.
if w.dtype != w_dtype:
warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}; casting to {w_dtype}")
return w.to(dtype=w_dtype)
return w


class AbstractMulticompartmentConnection(ABC, Module):
# language=rst
@@ -261,6 +269,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -275,6 +284,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
@@ -296,9 +306,11 @@ def __init__(
w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
else:
w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
w = w.to(dtype=w_dtype)
else:
if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)

@@ -525,6 +537,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -543,6 +556,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
@@ -595,9 +609,11 @@ def __init__(
self.out_channels, self.in_channels, self.kernel_size
)
w += self.wmin
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
@@ -667,6 +683,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -685,6 +702,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
@@ -750,9 +768,11 @@ def __init__(
self.out_channels, self.in_channels, *self.kernel_size
)
w += self.wmin
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
@@ -824,6 +844,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -842,6 +863,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
@@ -926,9 +948,11 @@ def __init__(
self.out_channels, self.in_channels, *self.kernel_size
)
w += self.wmin
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
@@ -1276,6 +1300,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -1299,6 +1324,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
@@ -1378,10 +1404,11 @@ def __init__(
w = torch.clamp(w, self.wmin, self.wmax)
else:
w = self.wmin + w * (self.wmax - self.wmin)

w = w.to(dtype=w_dtype)
else:
if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)

@@ -1456,6 +1483,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
"""
@@ -1474,6 +1502,7 @@ def __init__(
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to some rule.
:param torch.Tensor w: Strengths of synapses.
@@ -1507,12 +1536,14 @@ def __init__(
w = torch.rand(
self.in_channels, self.n_filters * self.conv_size, self.kernel_size
)
w = w.to(dtype=w_dtype)
else:
assert w.shape == (
self.in_channels,
self.out_channels * self.conv_size,
self.kernel_size,
), error
w = self.cast_dtype_if_needed(w, w_dtype)

if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)
@@ -1588,6 +1619,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
"""
@@ -1606,6 +1638,7 @@ def __init__(
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to some rule.
:param torch.Tensor w: Strengths of synapses.
@@ -1649,12 +1682,14 @@ def __init__(
w = torch.rand(
self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
)
w = w.to(dtype=w_dtype)
else:
assert w.shape == (
self.in_channels,
self.out_channels * self.conv_prod,
self.kernel_prod,
), error
w = self.cast_dtype_if_needed(w, w_dtype)

if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)
@@ -1731,6 +1766,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
"""
@@ -1749,6 +1785,7 @@ def __init__(
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to some rule.
:param torch.Tensor w: Strengths of synapses.
@@ -1794,12 +1831,14 @@ def __init__(
w = torch.rand(
self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
)
w = w.to(dtype=w_dtype)
else:
assert w.shape == (
self.in_channels,
self.out_channels * self.conv_prod,
self.kernel_prod,
), error
w = self.cast_dtype_if_needed(w, w_dtype)

if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)
@@ -1875,6 +1914,7 @@ def __init__(
target: Nodes,
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -1886,6 +1926,7 @@ def __init__(
accepts a pair of tensors to individualize learning rates of each neuron.
In this case, their shape should be the same size as the connection weights.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to
some rule.
@@ -1904,10 +1945,11 @@ def __init__(
w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
else:
w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
w = torch.clamp(w, self.wmin, self.wmax)

w = self.cast_dtype_if_needed(w, w_dtype)
self.w = Parameter(w, requires_grad=False)

def compute(self, s: torch.Tensor) -> torch.Tensor:
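For reference, a sketch of the cast-and-warn path when an explicit weight tensor is supplied whose dtype differs from w_dtype; the surrounding setup is illustrative, and only the w_dtype argument and the warning come from this commit:

    import warnings

    import torch

    from bindsnet.network.nodes import Input, LIFNodes
    from bindsnet.network.topology import Connection

    source = Input(n=10)
    target = LIFNodes(n=5)
    w_old = torch.rand(10, 5, dtype=torch.float64)  # e.g. weights restored from an older model

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        conn = Connection(source=source, target=target, w=w_old, w_dtype=torch.float32)

    print(conn.w.dtype)  # torch.float32
    print(any("w_dtype" in str(c.message) for c in caught))  # True: the dtype mismatch was reported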
