Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch 2.0 #630

Merged
merged 8 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
python-version: 3.9
- name: Install Poetry
env:
POETRY_VERSION: 1.1.13
POETRY_VERSION: 1.4.2
run: |
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python - &&\
poetry config virtualenvs.create false
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install Poetry
env:
POETRY_VERSION: 1.1.13
POETRY_VERSION: 1.4.2
run: |
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python - &&\
poetry config virtualenvs.create false
Expand Down
3 changes: 0 additions & 3 deletions bindsnet/datasets/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import collections

import torch
from torch._six import string_classes
from torch.utils.data._utils import collate as pytorch_collate


Expand Down Expand Up @@ -75,8 +74,6 @@ def time_aware_collate(batch):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.Mapping):
return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
Expand Down
4 changes: 1 addition & 3 deletions bindsnet/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import Any, Dict, Tuple

import torch
from torch._six import string_classes

from bindsnet.network import Network
from bindsnet.network.monitors import Monitor

Expand All @@ -23,7 +21,7 @@ def recursive_to(item, device):

if isinstance(item, torch.Tensor):
return item.to(device)
elif isinstance(item, (string_classes, int, float, bool)):
elif isinstance(item, (int, float, bool)):
return item
elif isinstance(item, collections.abc.Mapping):
return {key: recursive_to(item[key], device) for key in item}
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@

# Run network on sample image
network.run(inputs={"I": datum}, time=time)
training_pairs.append([spikes["O"].get("s").sum(0), label])
training_pairs.append([spikes["O"].get("s"), label])

# Plot spiking activity using monitors
if plot:
Expand Down Expand Up @@ -187,7 +187,7 @@ def forward(self, x):


# Create and train logistic regression model on reservoir outputs.
model = NN(n_neurons, 10).to(device)
model = NN(n_neurons * args.time, 10).to(device)
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

Expand Down
4,049 changes: 2,133 additions & 1,916 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bindsnet"
version = "0.3.1"
version = "0.3.2"
description = "Spiking neural networks for ML in Python"
authors = [ "Hananel Hazan <hananel@hazan.org.il>", "Daniel Saunders", "Darpan Sanghavi", "Hassaan Khan" ]
license = "AGPL-3.0-only"
Expand All @@ -14,13 +14,13 @@ python = ">=3.8,<3.11"
numpy = "^1.24.2"
scipy = "^1.9.1"
Cython = "^0.29.33"
torch = "1.13.1"
torchvision = "0.14.1"
torchaudio = "0.13.1"
torch = "2.0.0"
torchvision = "0.15.1"
torchaudio = "2.0.1"
tensorboardX = "2.6"
tqdm = "^4.65.0"
matplotlib = "^3.7.1"
gymnasium = "^0.27.1"
gymnasium = "^0.28.1"
scikit-build = "^0.16.7"
scikit-image = "^0.20.0"
scikit-learn = "^1.2.1"
Expand Down