Commit

w_dtype for example
n-shevko committed Mar 8, 2025
1 parent e0f9bf2 commit e9dfcbf
Showing 78 changed files with 98 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -18,6 +18,7 @@ test/models/__pycache__/*
test/network/__pycache__/*
test/analysis/__pycache__/*
*.pyc
**/*.pyc
dist/*
logs/*
.pytest_cache/*
Binary file removed bindsnet/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/__pycache__/utils.cpython-310.pyc
Binary file removed bindsnet/analysis/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/datasets/__pycache__/davis.cpython-310.pyc
Binary file removed bindsnet/encoding/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/learning/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/learning/__pycache__/reward.cpython-310.pyc
Binary file removed bindsnet/models/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/models/__pycache__/models.cpython-310.pyc
57 changes: 44 additions & 13 deletions bindsnet/models/models.py
@@ -6,9 +6,11 @@
from torch.nn.modules.utils import _pair

from bindsnet.learning import PostPre
from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
from bindsnet.network import Network
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
from bindsnet.network.topology import Connection, LocalConnection
from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
from bindsnet.network.topology_features import Weight


class TwoLayerNetwork(Network):
@@ -94,6 +96,7 @@ class DiehlAndCook2015(Network):
def __init__(
self,
n_inpt: int,
device: str = "cpu",
n_neurons: int = 100,
exc: float = 22.5,
inh: float = 17.5,
@@ -102,6 +105,7 @@ def __init__(
reduction: Optional[callable] = None,
wmin: float = 0.0,
wmax: float = 1.0,
w_dtype: torch.dtype = torch.float32,
norm: float = 78.4,
theta_plus: float = 0.05,
tc_theta_decay: float = 1e7,
@@ -124,6 +128,7 @@ def __init__(
dimension.
:param wmin: Minimum allowed weight on input to excitatory synapses.
:param wmax: Maximum allowed weight on input to excitatory synapses.
:param w_dtype: Data type for :code:`w` tensor
:param norm: Input to excitatory layer connection weights normalization
constant.
:param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane
@@ -170,27 +175,53 @@ def __init__(

# Connections
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
input_exc_conn = Connection(
input_exc_conn = MulticompartmentConnection(
source=input_layer,
target=exc_layer,
w=w,
update_rule=PostPre,
nu=nu,
reduction=reduction,
wmin=wmin,
wmax=wmax,
norm=norm,
device=device,
pipeline=[
Weight(
'weight',
w,
value_dtype=w_dtype,
range=[wmin, wmax],
norm=norm,
reduction=reduction,
nu=nu,
learning_rule=MMCPostPre
)
]
)
w = self.exc * torch.diag(torch.ones(self.n_neurons))
exc_inh_conn = Connection(
source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
exc_inh_conn = MulticompartmentConnection(
source=exc_layer,
target=inh_layer,
device=device,
pipeline=[
Weight(
'weight',
w,
value_dtype=w_dtype,
range=[0, self.exc]
)
]
)
w = -self.inh * (
torch.ones(self.n_neurons, self.n_neurons)
- torch.diag(torch.ones(self.n_neurons))
)
inh_exc_conn = Connection(
source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
inh_exc_conn = MulticompartmentConnection(
source=inh_layer,
target=exc_layer,
device=device,
pipeline=[
Weight(
'weight',
w,
value_dtype=w_dtype,
range=[-self.inh, 0]
)
]
)

# Add to network
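For illustration only (not part of the commit), a minimal sketch of how the updated constructor might be called. The import path and the `torch.float16` choice are assumptions; every other argument keeps the defaults shown in the diff above.

```python
import torch

from bindsnet.models import DiehlAndCook2015  # assumed re-export, as used elsewhere in the library

# Sketch: pass the two arguments added in this commit. `w_dtype` is forwarded to the
# Weight features inside each MulticompartmentConnection pipeline; `device` is passed
# to the connections themselves.
network = DiehlAndCook2015(
    n_inpt=28 * 28,          # e.g. a flattened 28x28 input
    device="cpu",
    n_neurons=100,
    w_dtype=torch.float16,   # illustrative non-default dtype
)
```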
Binary file removed bindsnet/network/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/network/__pycache__/network.cpython-310.pyc
Binary file removed bindsnet/network/__pycache__/nodes.cpython-310.pyc
4 changes: 2 additions & 2 deletions bindsnet/network/topology.py
@@ -141,8 +141,8 @@ def reset_state_variables(self) -> None:
Contains resetting logic for the connection.
"""

@abstractmethod
def cast_dtype_if_needed(self, w, w_dtype):
@staticmethod
def cast_dtype_if_needed(w, w_dtype):
if w.dtype != w_dtype:
warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
return w.to(dtype=w_dtype)
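For context, a small usage sketch of the now-static helper. It assumes, based on the hunk context, that the method lives on the abstract connection class in `bindsnet.network.topology`; the tensor and dtypes below are arbitrary.

```python
import torch

from bindsnet.network.topology import AbstractConnection  # assumed host class for the helper

w = torch.rand(10, 10)  # torch.float32 by default

# If the tensor's dtype differs from w_dtype, the helper warns and returns a cast copy;
# otherwise it returns the tensor unchanged. As a @staticmethod it can be called without
# instantiating a connection.
w_half = AbstractConnection.cast_dtype_if_needed(w, torch.float16)
assert w_half.dtype == torch.float16
```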
50 changes: 42 additions & 8 deletions bindsnet/network/topology_features.py
@@ -4,6 +4,7 @@

import numpy as np
import torch
import warnings
from torch import device
from torch.nn import Parameter
import torch.nn.functional as F
@@ -22,6 +23,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
range: Optional[Union[list, tuple]] = None,
clamp_frequency: Optional[int] = 1,
norm: Optional[Union[torch.Tensor, float, int]] = None,
@@ -38,6 +40,7 @@
Instantiates a :code:`Feature` object. Will assign all incoming arguments as class variables
:param name: Name of the feature
:param value: Core numeric object for the feature. This parameter's function will vary depending on the feature
:param value_dtype: Data type for :code:`value` tensor
:param range: Range of acceptable values for the :code:`value` parameter
:param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each
sample and after the value has been updated by the learning rule (if there is one)
@@ -119,6 +122,15 @@ def __init__(
self.assert_valid_range()
if value is not None:
self.assert_feature_in_range()
self.value = self.cast_dtype_if_needed(self.value, value_dtype)

@staticmethod
def cast_dtype_if_needed(value, value_dtype):
if value.dtype != value_dtype:
warnings.warn(f"Provided value has data type {value.dtype} but parameter w_dtype is {value_dtype}")
return value.to(dtype=value_dtype)
else:
return value

@abstractmethod
def reset_state_variables(self) -> None:
@@ -312,6 +324,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
range: Optional[Sequence[float]] = None,
norm: Optional[Union[torch.Tensor, float, int]] = None,
learning_rule: Optional[bindsnet.learning.LearningRule] = None,
@@ -327,6 +340,7 @@ def __init__(
:param value: Number(s) in [0, 1] which represent the probability of a signal traversing a synapse. Tensor values
assume that probabilities will be matched to adjacent synapses in the connection. Scalars will be applied to
all synapses.
:param value_dtype: Data type for :code:`value` tensor
:param range: Range of acceptable values for the :code:`value` parameter. Should be in [0, 1]
:param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
and after the value has been updated by the learning rule (if there is one)
@@ -342,6 +356,7 @@
super().__init__(
name=name,
value=value,
value_dtype=value_dtype,
range=[0, 1] if range is None else range,
norm=norm,
learning_rule=learning_rule,
@@ -419,6 +434,7 @@ def __init__(
super().__init__(
name=name,
value=value,
value_dtype=torch.bool
)

self.name = name
@@ -497,6 +513,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
range: Optional[Sequence[float]] = None,
norm: Optional[Union[torch.Tensor, float, int]] = None,
norm_frequency: Optional[str] = "sample",
@@ -511,6 +528,7 @@ def __init__(
Multiplies signals by scalars
:param name: Name of the feature
:param value: Values to scale signals by
:param value_dtype: Data type for :code:`value` tensor
:param range: Range of acceptable values for the :code:`value` parameter
:param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
and after the value has been updated by the learning rule (if there is one)
@@ -530,6 +548,7 @@ def __init__(
super().__init__(
name=name,
value=value,
value_dtype=value_dtype,
range=[-torch.inf, +torch.inf] if range is None else range,
norm=norm,
learning_rule=learning_rule,
@@ -587,6 +606,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
range: Optional[Sequence[float]] = None,
norm: Optional[Union[torch.Tensor, float, int]] = None,
) -> None:
@@ -595,6 +615,7 @@ def __init__(
Adds scalars to signals
:param name: Name of the feature
:param value: Values to add to the signals
:param value_dtype: Data type for :code:`value` tensor
:param range: Range of acceptable values for the :code:`value` parameter
:param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
and after the value has been updated by the learning rule (if there is one)
@@ -603,6 +624,7 @@ def __init__(
super().__init__(
name=name,
value=value,
value_dtype=value_dtype,
range=[-torch.inf, +torch.inf] if range is None else range,
norm=norm,
)
@@ -628,16 +650,18 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
range: Optional[Sequence[float]] = None,
) -> None:
# language=rst
"""
Adds scalars to signals
:param name: Name of the feature
:param value: Values to scale signals by
:param value_dtype: Data type for :code:`value` tensor
"""

super().__init__(name=name, value=value, range=range)
super().__init__(name=name, value=value, value_dtype=value_dtype, range=range)

def reset_state_variables(self) -> None:
pass
@@ -664,6 +688,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
degrade_function: callable = None,
parent_feature: Optional[AbstractFeature] = None,
) -> None:
@@ -673,13 +698,14 @@
Note: If :code:`parent_feature` is provided, it will override :code:`value`.
:param name: Name of the feature
:param value: Value used to degrade feature
:param value_dtype: Data type for :code:`value` tensor
:param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or
constant to be *subtracted* from the propagating spikes.
:param parent_feature: Parent feature with desired :code:`value` to inherit
"""

# Note: parent_feature will override value. See abstract constructor
super().__init__(name=name, value=value, parent_feature=parent_feature)
super().__init__(name=name, value=value, value_dtype=value_dtype, parent_feature=parent_feature)

self.degrade_function = degrade_function

@@ -695,6 +721,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
ann_values: Union[list, tuple] = None,
const_update_rate: float = 0.1,
const_decay: float = 0.001,
Expand All @@ -710,6 +737,9 @@ def __init__(
:param const_decay: The spontaneous activation of the synapses.
"""

self.value_dtype = value_dtype
value = value.to(self.value_dtype)

# Define the ANN
class ANN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
@@ -743,7 +773,7 @@ def forward(self, x):
self.const_update_rate = const_update_rate
self.const_decay = const_decay

super().__init__(name=name, value=value)
super().__init__(name=name, value=value, value_dtype=self.value_dtype)

def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:

@@ -758,15 +788,15 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
# Update the masks
if self.counter % self.spike_buffer.shape[1] == 0:
with torch.no_grad():
ann_decision = self.ann(self.spike_buffer.to(torch.float32))
ann_decision = self.ann(self.spike_buffer.to(self.value_dtype))
self.mask += (
ann_decision.view(self.mask.shape) * self.const_update_rate
) # update mask with learning rate fraction
self.mask += self.const_decay # spontaneous activate synapses
self.mask = torch.clamp(self.mask, -1, 1) # cap the mask

# self.mask = torch.clamp(self.mask, -1, 1)
self.value = (self.mask > 0).float()
self.value = (self.mask > 0).to(self.value_dtype)

return conn_spikes * self.value

@@ -785,6 +815,7 @@ def __init__(
self,
name: str,
value: Union[torch.Tensor, float, int] = None,
value_dtype: torch.dtype = torch.float32,
ann_values: Union[list, tuple] = None,
const_update_rate: float = 0.1,
const_decay: float = 0.01,
@@ -796,9 +827,12 @@ def __init__(
:param name: Name of the feature
:param ann_values: Values to be used to build an ANN that will adapt the connectivity of the layer.
:param value: Values to be used to build an initial mask for the synapses.
:param value_dtype: Data type for :code:`value` tensor
:param const_update_rate: The mask update rate of the ANN decision.
:param const_decay: The spontaneous activation of the synapses.
"""
self.value_dtype = value_dtype
value = value.to(self.value_dtype)

# Define the ANN
class ANN(nn.Module):
@@ -833,7 +867,7 @@ def forward(self, x):
self.const_update_rate = const_update_rate
self.const_decay = const_decay

super().__init__(name=name, value=value)
super().__init__(name=name, value=value, value_dtype=self.value_dtype)

def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:

@@ -848,15 +882,15 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
# Update the masks
if self.counter % self.spike_buffer.shape[1] == 0:
with torch.no_grad():
ann_decision = self.ann(self.spike_buffer.to(torch.float32))
ann_decision = self.ann(self.spike_buffer.to(self.value_dtype))
self.mask += (
ann_decision.view(self.mask.shape) * self.const_update_rate
) # update mask with learning rate fraction
self.mask += self.const_decay # spontaneous activate synapses
self.mask = torch.clamp(self.mask, -1, 1) # cap the mask

# self.mask = torch.clamp(self.mask, -1, 1)
self.value = (self.mask > 0).float()
self.value = (self.mask > 0).to(self.value_dtype)

return conn_spikes * self.value

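To show the feature-level counterpart, a minimal sketch (shapes, bounds, and dtype are assumptions) of constructing a `Weight` feature with a non-default `value_dtype`. Because the tensor starts out as `float32`, the `cast_dtype_if_needed` path added above warns and stores a cast copy.

```python
import torch

from bindsnet.network.topology_features import Weight

w = 0.3 * torch.rand(784, 100)  # torch.float32 by default

feature = Weight(
    "weight",
    w,
    value_dtype=torch.float16,  # differs from w.dtype, so a warning is emitted and w is cast
    range=[0.0, 1.0],
    norm=78.4,
)
assert feature.value.dtype == torch.float16
```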
Binary file removed bindsnet/pipeline/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/pipeline/__pycache__/action.cpython-310.pyc