Skip to content

Commit

Permalink
Add sparse support for other learning rules
Browse files Browse the repository at this point in the history
  • Loading branch information
n-shevko committed Mar 8, 2025
1 parent bed0623 commit 326d270
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions bindsnet/learning/MCC_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,16 +523,18 @@ def _connection_update(self, **kwargs) -> None:
self.average_buffer_index + 1
) % self.average_update

if self.continues_update:
self.feature_value += self.nu[0] * torch.mean(
self.average_buffer, dim=0
)
elif self.average_buffer_index == 0:
self.feature_value += self.nu[0] * torch.mean(
if self.continues_update or self.average_buffer_index == 0:
update = self.nu[0] * torch.mean(
self.average_buffer, dim=0
)
if self.feature_value.is_sparse:
update = update.to_sparse()
self.feature_value += update
else:
self.feature_value += self.nu[0] * self.reduction(update, dim=0)
update = self.nu[0] * self.reduction(update, dim=0)
if self.feature_value.is_sparse:
update = update.to_sparse()
self.feature_value += update

# Update P^+ and P^- values.
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
Expand Down Expand Up @@ -701,14 +703,16 @@ def _connection_update(self, **kwargs) -> None:
self.average_buffer_index + 1
) % self.average_update

if self.continues_update:
self.feature_value += torch.mean(self.average_buffer, dim=0)
elif self.average_buffer_index == 0:
self.feature_value += torch.mean(self.average_buffer, dim=0)
if self.continues_update or self.average_buffer_index == 0:
update = torch.mean(self.average_buffer, dim=0)
if self.feature_value.is_sparse:
update = update.to_sparse()
self.feature_value += update
else:
self.feature_value += (
self.nu[0] * self.connection.dt * reward * self.eligibility_trace
)
update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
if self.feature_value.is_sparse:
update = update.to_sparse()
self.feature_value += update

# Update P^+ and P^- values.
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) # Decay
Expand Down

0 comments on commit 326d270

Please sign in to comment.