Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encoding -> single #484

Merged
merged 1 commit into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 10 additions & 48 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def plot_spikes(
for i, datum in enumerate(spikes.items()):
spikes = (
datum[1][
time[0] : time[1],
n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1]
]
.detach()
.clone()
Expand Down Expand Up @@ -144,19 +143,13 @@ def plot_spikes(
for ax in axes:
ax.set_aspect("auto")

plt.setp(
axes,
xticks=[],
xlabel="Simulation time",
ylabel="Neuron index",
)
plt.setp(axes, xticks=[], xlabel="Simulation time", ylabel="Neuron index")
plt.tight_layout()
else:
for i, datum in enumerate(spikes.items()):
spikes = (
datum[1][
time[0] : time[1],
n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1]
]
.detach()
.clone()
Expand Down Expand Up @@ -424,10 +417,7 @@ def plot_assignments(
else:
color = plt.get_cmap("RdBu", len(classes) + 1)
im = ax.matshow(
locals_assignments,
cmap=color,
vmin=-1.5,
vmax=len(classes) - 0.5,
locals_assignments, cmap=color, vmin=-1.5, vmax=len(classes) - 0.5
)

div = make_axes_locatable(ax)
Expand Down Expand Up @@ -616,9 +606,7 @@ def plot_voltages(
):
ims.append(
axes.axhline(
y=thresholds[v[0]].item(),
c="r",
linestyle="--",
y=thresholds[v[0]].item(), c="r", linestyle="--"
)
)
else:
Expand All @@ -635,13 +623,7 @@ def plot_voltages(
)
)

args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
plt.title("%s voltages for neurons (%d - %d) from t = %d to %d " % args)
plt.xlabel("Time (ms)")

Expand Down Expand Up @@ -670,9 +652,7 @@ def plot_voltages(
):
ims.append(
axes[i].axhline(
y=thresholds[v[0]].item(),
c="r",
linestyle="--",
y=thresholds[v[0]].item(), c="r", linestyle="--"
)
)
else:
Expand All @@ -688,13 +668,7 @@ def plot_voltages(
cmap=cmap,
)
)
args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
axes[i].set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
Expand Down Expand Up @@ -736,13 +710,7 @@ def plot_voltages(
.T,
cmap=cmap,
)
args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
axes.set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
Expand Down Expand Up @@ -776,13 +744,7 @@ def plot_voltages(
.T,
cmap=cmap,
)
args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
axes[i].set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
Expand Down
5 changes: 1 addition & 4 deletions bindsnet/datasets/alov300.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,7 @@ def get_sample(self, idx):
bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3])
bbox_gt_recentered = BoundingBox(0, 0, 0, 0)
bbox_gt_recentered = bbox_curr_gt.recenter(
rand_search_location,
edge_spacing_x,
edge_spacing_y,
bbox_gt_recentered,
rand_search_location, edge_spacing_x, edge_spacing_y, bbox_gt_recentered
)
curr_sample["image"] = rand_search_region
curr_sample["bb"] = bbox_gt_recentered.get_bb_list()
Expand Down
26 changes: 8 additions & 18 deletions bindsnet/datasets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,7 @@ def shift_crop_training_sample(sample, bb_params):
bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3])
bbox_gt_recentered = BoundingBox(0, 0, 0, 0)
bbox_gt_recentered = bbox_curr_gt.recenter(
rand_search_location,
edge_spacing_x,
edge_spacing_y,
bbox_gt_recentered,
rand_search_location, edge_spacing_x, edge_spacing_y, bbox_gt_recentered
)
output_sample["image"] = rand_search_region
output_sample["bb"] = bbox_gt_recentered.get_bb_list()
Expand All @@ -155,12 +152,9 @@ def crop_sample(sample):
opts = {}
image, bb = sample["image"], sample["bb"]
orig_bbox = BoundingBox(bb[0], bb[1], bb[2], bb[3])
(
output_image,
pad_image_location,
edge_spacing_x,
edge_spacing_y,
) = cropPadImage(orig_bbox, image)
(output_image, pad_image_location, edge_spacing_x, edge_spacing_y) = cropPadImage(
orig_bbox, image
)
new_bbox = BoundingBox(0, 0, 0, 0)
new_bbox = new_bbox.recenter(
pad_image_location, edge_spacing_x, edge_spacing_y, new_bbox
Expand Down Expand Up @@ -198,8 +192,7 @@ def cropPadImage(bbox_tight, image):
output_height = max(math.ceil(bbox_tight.compute_output_height()), roi_height)
if image.ndim > 2:
output_image = np.zeros(
(int(output_height), int(output_width), image.shape[2]),
dtype=image.dtype,
(int(output_height), int(output_width), image.shape[2]), dtype=image.dtype
)
else:
output_image = np.zeros(
Expand Down Expand Up @@ -392,8 +385,7 @@ def shift(
):
if shift_motion_model:
width_scale_factor = max(
min_scale,
min(max_scale, sample_exp_two_sides(lambda_scale_frac)),
min_scale, min(max_scale, sample_exp_two_sides(lambda_scale_frac))
)
else:
rand_num = sample_rand_uniform()
Expand All @@ -410,8 +402,7 @@ def shift(
):
if shift_motion_model:
height_scale_factor = max(
min_scale,
min(max_scale, sample_exp_two_sides(lambda_scale_frac)),
min_scale, min(max_scale, sample_exp_two_sides(lambda_scale_frac))
)
else:
rand_num = sample_rand_uniform()
Expand Down Expand Up @@ -464,8 +455,7 @@ def shift(
new_y_temp = center_y + rand_num * (2 * new_height) - new_height

new_center_y = min(
image.shape[0] - new_height / 2,
max(new_height / 2, new_y_temp),
image.shape[0] - new_height / 2, max(new_height / 2, new_y_temp)
)
first_time_y = False
num_tries_y = num_tries_y + 1
Expand Down
9 changes: 4 additions & 5 deletions bindsnet/encoding/encodings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

import torch
import numpy as np


def single(
Expand All @@ -27,10 +26,10 @@ def single(
"""
time = int(time / dt)
shape = list(datum.shape)
datum = np.copy(datum)
quantile = np.quantile(datum, 1 - sparsity)
s = np.zeros([time, *shape], device=device)
s[0] = np.where(datum > quantile, np.ones(shape), np.zeros(shape))
datum = torch.tensor(datum)
quantile = torch.quantile(datum, 1 - sparsity)
s = torch.zeros([time, *shape], device=device)
s[0] = torch.where(datum > quantile, torch.ones(shape), torch.zeros(shape))
return torch.Tensor(s).byte()


Expand Down
4 changes: 1 addition & 3 deletions bindsnet/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,7 @@ def __init__(
w = w / w.max()
w = (w * self.max_inhib) + self.start_inhib
recurrent_output_conn = Connection(
source=self.layers["Y"],
target=self.layers["Y"],
w=w,
source=self.layers["Y"], target=self.layers["Y"], w=w
)
self.add_connection(recurrent_output_conn, source="Y", target="Y")

Expand Down
3 changes: 1 addition & 2 deletions bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,7 @@ def __init__(

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
kwargs.get("b", torch.zeros(self.out_channels)),
requires_grad=False,
kwargs.get("b", torch.zeros(self.out_channels)), requires_grad=False
)

def compute(self, s: torch.Tensor) -> torch.Tensor:
Expand Down
10 changes: 2 additions & 8 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
from bindsnet import ROOT_DIR
from bindsnet.datasets import MNIST, DataLoader
from bindsnet.encoding import PoissonEncoder
from bindsnet.evaluation import (
all_activity,
proportion_weighting,
assign_labels,
)
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_weights, get_square_assignments
Expand Down Expand Up @@ -201,9 +197,7 @@

# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record,
assignments=assignments,
n_labels=n_classes,
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
Expand Down
18 changes: 3 additions & 15 deletions examples/mnist/supervised_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_assignments, get_square_weights
from bindsnet.evaluation import (
all_activity,
proportion_weighting,
assign_labels,
)
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.analysis.plotting import (
plot_input,
plot_assignments,
Expand Down Expand Up @@ -183,11 +179,7 @@

print(
"\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
% (
accuracy["all"][-1],
np.mean(accuracy["all"]),
np.max(accuracy["all"]),
)
% (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]))
)
print(
"Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
Expand Down Expand Up @@ -233,11 +225,7 @@
voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

inpt_axes, inpt_ims = plot_input(
image.sum(1).view(28, 28),
inpt,
label=label,
axes=inpt_axes,
ims=inpt_ims,
image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims
)
spike_ims, spike_axes = plot_spikes(
{layer: spikes[layer].get("s").view(time, 1, -1) for layer in spikes},
Expand Down