diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index eb44627f..2432011f 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -408,7 +408,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: decaying spike activation). """ self.firing_rates -= self.decay * self.firing_rates - self.firing_rates += s.float() + self.firing_rates += s.float().squeeze() _, indices = F.max_pool2d( self.firing_rates,