From df96c4c16404d636711415ace05e6d8017bdb94f Mon Sep 17 00:00:00 2001 From: Hananel Hazan Date: Mon, 15 Feb 2021 13:35:21 -0500 Subject: [PATCH] Fix plotting issue with batch_eth_mnist and code black formater --- examples/mnist/SOM_LM-SNNs.py | 4 +++- examples/mnist/batch_eth_mnist.py | 17 ++++++++++------- examples/mnist/eth_mnist.py | 8 ++++++-- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/mnist/SOM_LM-SNNs.py b/examples/mnist/SOM_LM-SNNs.py index 69c4ebbe..2bd6e14b 100644 --- a/examples/mnist/SOM_LM-SNNs.py +++ b/examples/mnist/SOM_LM-SNNs.py @@ -120,7 +120,9 @@ accuracy = {"all": [], "proportion": []} # Voltage recording for excitatory and inhibitory layers. -som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=int(time / dt), device=device) +som_voltage_monitor = Monitor( + network.layers["Y"], ["v"], time=int(time / dt), device=device +) network.add_monitor(som_voltage_monitor, name="som_voltage") # Set up monitors for spikes and voltages diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 35199e34..6dbee43c 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -133,8 +133,12 @@ accuracy = {"all": [], "proportion": []} # Voltage recording for excitatory and inhibitory layers. -exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=int(time / dt), device=device) -inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=int(time / dt), device=device) +exc_voltage_monitor = Monitor( + network.layers["Ae"], ["v"], time=int(time / dt), device=device +) +inh_voltage_monitor = Monitor( + network.layers["Ai"], ["v"], time=int(time / dt), device=device +) network.add_monitor(exc_voltage_monitor, name="exc_voltage") network.add_monitor(inh_voltage_monitor, name="inh_voltage") @@ -142,16 +146,14 @@ spikes = {} for layer in set(network.layers): spikes[layer] = Monitor( - network.layers[layer], state_vars=["s"], time=int(time / dt), - device=device + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device ) network.add_monitor(spikes[layer], name="%s_spikes" % layer) voltages = {} for layer in set(network.layers) - {"X"}: voltages[layer] = Monitor( - network.layers[layer], state_vars=["v"], time=int(time / dt), - device=device + network.layers[layer], state_vars=["v"], time=int(time / dt), device=device ) network.add_monitor(voltages[layer], name="%s_voltages" % layer) @@ -271,6 +273,7 @@ if plot: image = batch["image"][:, 0].view(28, 28) inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28) + lable = batch["label"][0] input_exc_weights = network.connections[("X", "Ae")].w square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28 @@ -281,7 +284,7 @@ } voltages = {"Ae": exc_voltages, "Ai": inh_voltages} inpt_axes, inpt_ims = plot_input( - image, inpt, label=labels[step], axes=inpt_axes, ims=inpt_ims + image, inpt, label=lable, axes=inpt_axes, ims=inpt_ims ) spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) diff --git a/examples/mnist/eth_mnist.py b/examples/mnist/eth_mnist.py index a6372c45..e14eda0c 100644 --- a/examples/mnist/eth_mnist.py +++ b/examples/mnist/eth_mnist.py @@ -134,8 +134,12 @@ accuracy = {"all": [], "proportion": []} # Voltage recording for excitatory and inhibitory layers. -exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=int(time / dt), device=device) -inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=int(time / dt), device=device) +exc_voltage_monitor = Monitor( + network.layers["Ae"], ["v"], time=int(time / dt), device=device +) +inh_voltage_monitor = Monitor( + network.layers["Ai"], ["v"], time=int(time / dt), device=device +) network.add_monitor(exc_voltage_monitor, name="exc_voltage") network.add_monitor(inh_voltage_monitor, name="inh_voltage")