Skip to content

Commit

Permalink
Merge pull request #3 from scil-vital/atheb/devices
Browse files Browse the repository at this point in the history
ENH: CPU inference
  • Loading branch information
AntoineTheb authored Feb 4, 2025
2 parents a9ad462 + 8daf6eb commit 8ae68d0
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 15 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Install non-python dependencies
run: |
sudo apt-get update
sudo apt-get install -y \
build-essential \
curl \
git \
libblas-dev \
liblapack-dev \
libfreetype6-dev
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
export SETUPTOOLS_USE_DISTUTILS=stdlib; pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
10 changes: 10 additions & 0 deletions LabelSeg/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch


def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
]
description = "LabelSeg"
readme = "README.md"
requires-python = "==3.10.*"
requires-python = ">=3.10"
keywords = [""]
license = {text = "MIT"}
classifiers = [
Expand All @@ -30,7 +30,7 @@ dependencies = [
"scilpy",
'importlib-metadata; python_version<"3.10"',
]
dynamic = ["version"]
version = "0.0.1"

[project.scripts]
labelseg_train = "scripts.labelseg_train:main"
Expand Down
4 changes: 4 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

[pytest]

python_files = test*.py
9 changes: 0 additions & 9 deletions requirements.txt

This file was deleted.

6 changes: 2 additions & 4 deletions scripts/labelseg_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
from scilpy.image.volume_operations import resample_volume

from LabelSeg.models.utils import get_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cast_device = 'cuda' if torch.cuda.is_available() else 'cpu'
from LabelSeg.utils.utils import get_device

# TODO: Get bundle list from model
DEFAULT_BUNDLES = ['AF_left', 'AF_right', 'ATR_left', 'ATR_right', 'CA', 'CC_1', 'CC_2', 'CC_3', 'CC_4', 'CC_5', 'CC_6', 'CC_7', 'CG_left', 'CG_right', 'CST_left', 'CST_right', 'FPT_left', 'FPT_right', 'FX_left', 'FX_right', 'ICP_left', 'ICP_right', 'IFO_left', 'IFO_right', 'ILF_left', 'ILF_right', 'MCP', 'MLF_left', 'MLF_right', 'OR_left', 'OR_right', 'POPT_left', 'POPT_right', 'SCP_left', 'SCP_right', 'SLF_III_left', 'SLF_III_right', 'SLF_II_left', 'SLF_II_right', 'SLF_I_left', 'SLF_I_right', 'STR_left', 'STR_right', 'ST_FO_left', 'ST_FO_right', 'ST_OCC_left', 'ST_OCC_right', 'ST_PAR_left', 'ST_PAR_right', 'ST_POSTC_left', 'ST_POSTC_right', 'ST_PREC_left', 'ST_PREC_right', 'ST_PREF_left', 'ST_PREF_right', 'ST_PREM_left', 'ST_PREM_right', 'T_OCC_left', 'T_OCC_right', 'T_PAR_left', 'T_PAR_right', 'T_POSTC_left', 'T_POSTC_right', 'T_PREC_left', 'T_PREC_right', 'T_PREF_left', 'T_PREF_right', 'T_PREM_left', 'T_PREM_right', 'UF_left', 'UF_right'] # noqa E501
Expand Down Expand Up @@ -148,7 +146,7 @@ def predict(self, model, fodf, wm):
dtype=torch.float
).to('cuda:0')

prompts = torch.eye(len(self.bundles), device='cuda:0')
prompts = torch.eye(len(self.bundles), device=get_device())

with torch.no_grad():

Expand Down
2 changes: 2 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_dummy():
assert True

0 comments on commit 8ae68d0

Please sign in to comment.