
Commit 82ce2bf (1 parent: 2c446da)
Author: iurada
Message: Updated repository


45 files changed: +6350 −0 lines

.gitignore (+5 lines)

```
__pycache__/
env/
record/
wandb/
.DS_Store
```

README.md (+58 lines)

# Finding Lottery Tickets in Vision Models via Data-driven Spectral Foresight Pruning [CVPR 2024]
Official code of our work "Finding Lottery Tickets in Vision Models via Data-driven Spectral Foresight Pruning", accepted at CVPR 2024.

<p align="center"><img width="50%" src="./assets/teaser.png"></p>

## Introduction
<i>Recent advances in neural network pruning have shown how it is possible to reduce the computational costs and memory demands of deep learning models before training. We focus on this framework and propose a new pruning-at-initialization algorithm that leverages Neural Tangent Kernel (NTK) theory to align the training dynamics of the sparse network with those of the dense one. Specifically, we show how the usually neglected data-dependent component of the NTK's spectrum can be taken into account by providing an analytical upper bound to the NTK's trace, obtained by decomposing neural networks into individual paths. This leads to our Path eXclusion (PX), a foresight pruning method designed to preserve the parameters that most influence the NTK's trace. PX is able to find lottery tickets (i.e., good paths) even at high sparsity levels and largely reduces the need for additional training. When applied to pre-trained models, it extracts subnetworks directly usable for several downstream tasks, resulting in performance comparable to that of the dense counterpart but with substantial cost and computational savings.</i>
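For orientation, a hedged sketch of the quantity the abstract refers to (notation is ours, not taken from this codebase): for a network $f$ with parameters $\theta$ evaluated on a dataset $\mathcal{D}$, the empirical NTK and its trace can be written as

$$\Theta(x, x') = \nabla_\theta f(x)\, \nabla_\theta f(x')^{\top}, \qquad \operatorname{Tr}(\Theta) = \sum_{x \in \mathcal{D}} \lVert \nabla_\theta f(x) \rVert^{2},$$

and PX prunes so as to preserve the parameters contributing most to an analytical upper bound on this trace, obtained by decomposing $f$ into individual input-output paths.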
# Setting up the environment
### Requirements
Make sure to have a CUDA-capable device running at least CUDA 11.7. Throughout our experiments we used Python 3.10.9.

### General Dependencies
To install all the required dependencies, go to the root folder of this project and run:
```bash
pip install -r requirements.txt
```

### Datasets
1. The CIFAR-10, CIFAR-100 and Tiny-ImageNet datasets will be downloaded automatically to the folder specified by the `CONFIG.dataset_args['root']` argument once you run the experiments.
2. ImageNet needs to be downloaded from the official website: https://www.image-net.org/.
3. Pascal VOC2012 can be downloaded by running the following script:
```python
import os
import tarfile
from torchvision.datasets.utils import download_url

db = {'2012': {
    'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
    'filename': 'VOCtrainval_11-May-2012.tar',
    'md5': '6cd6e144f989b92b3379bac3b3de84fd',
    'base_dir': 'VOCdevkit/VOC2012'
}}

def download_extract(url, root, filename, md5):
    download_url(url, root, filename, md5)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)

if __name__ == '__main__':
    download_extract(db['2012']['url'], 'data/VOC2012', db['2012']['filename'], db['2012']['md5'])
```

At this point you should be able to run the provided code.

## Running The Experiments
Please refer to the `parse_args.py` file for the full list of available command-line arguments. The `launch_scripts/` folder contains example scripts used to run the experiments.

## Acknowledgement
Our code is developed starting from the [Synflow](https://arxiv.org/abs/2006.05467) code repository: https://github.com/ganguli-lab/Synaptic-Flow.

# Citation
```
@inproceedings{iurada2024finding,
    author={Iurada, Leonardo and Ciccone, Marco and Tommasi, Tatiana},
    booktitle={CVPR},
    title={Finding Lottery Tickets in Vision Models via Data-driven Spectral Foresight Pruning},
    year={2024}
}
```

assets/teaser.png (659 KB, binary file added)

data/.gitignore (+2 lines)

```
*
!.gitignore
```

datasets/CIFAR10/dataset.py (+48 lines)

```python
from datasets.utils import SeededDataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
from globals import CONFIG

def get_transform(size, padding, mean, std, preprocess):
    transform = []
    transform.append(T.Resize((size, size)))
    if preprocess:
        transform.append(T.RandomCrop(size=size, padding=padding))
        transform.append(T.RandomHorizontalFlip())
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean, std))
    return T.Compose(transform)

def load_data():
    size = 32
    if 'pretrain' in CONFIG.experiment_args:
        size = 224

    CONFIG.num_classes = 10
    CONFIG.data_input_size = (3, size, size)

    mean, std = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)
    train_transform = get_transform(size=size, padding=4, mean=mean, std=std, preprocess=True)
    test_transform = get_transform(size=size, padding=4, mean=mean, std=std, preprocess=False)
    train_dataset = CIFAR10(CONFIG.dataset_args['root'], train=True, download=True, transform=train_transform)
    test_dataset = CIFAR10(CONFIG.dataset_args['root'], train=False, download=True, transform=test_transform)

    train_loader = SeededDataLoader(
        train_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    test_loader = SeededDataLoader(
        test_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=False,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    return {'train': train_loader, 'test': test_loader}
```
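The `SeededDataLoader` wrapper itself is not part of this commit excerpt. Presumably it fixes random seeds so that shuffling and augmentation are reproducible across runs and workers. A minimal, stdlib-only sketch of that idea (the names `make_worker_init_fn` and `BASE_SEED` are ours, not from the repository):

```python
import random

BASE_SEED = 42  # hypothetical global experiment seed

def make_worker_init_fn(base_seed):
    # Return a function suitable as a DataLoader worker_init_fn: each worker
    # gets a distinct but deterministic seed derived from the base seed.
    def worker_init_fn(worker_id):
        seed = base_seed + worker_id
        random.seed(seed)  # seed this worker's RNG reproducibly
        return seed
    return worker_init_fn

init = make_worker_init_fn(BASE_SEED)
print([init(w) for w in range(4)])  # prints [42, 43, 44, 45]
```

In PyTorch, the same pattern would be passed via `DataLoader(..., worker_init_fn=init, generator=...)` so that every epoch's augmentation stream is replayable.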

datasets/CIFAR100/dataset.py (+48 lines)

```python
from datasets.utils import SeededDataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR100
from globals import CONFIG

def get_transform(size, padding, mean, std, preprocess):
    transform = []
    transform.append(T.Resize((size, size)))
    if preprocess:
        transform.append(T.RandomCrop(size=size, padding=padding))
        transform.append(T.RandomHorizontalFlip())
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean, std))
    return T.Compose(transform)

def load_data():
    size = 32
    if 'pretrain' in CONFIG.experiment_args:
        size = 224

    CONFIG.num_classes = 100
    CONFIG.data_input_size = (3, size, size)

    mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
    train_transform = get_transform(size=size, padding=4, mean=mean, std=std, preprocess=True)
    test_transform = get_transform(size=size, padding=4, mean=mean, std=std, preprocess=False)
    train_dataset = CIFAR100(CONFIG.dataset_args['root'], train=True, download=True, transform=train_transform)
    test_dataset = CIFAR100(CONFIG.dataset_args['root'], train=False, download=True, transform=test_transform)

    train_loader = SeededDataLoader(
        train_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    test_loader = SeededDataLoader(
        test_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=False,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    return {'train': train_loader, 'test': test_loader}
```

datasets/ImageNet/dataset.py (+57 lines)

```python
import os
from datasets.utils import SeededDataLoader
import torchvision.transforms as T

from torchvision import datasets

from globals import CONFIG

def get_transform(train, mean, std):
    if train:
        transform = T.Compose([
            T.RandomResizedCrop(224, scale=(0.2, 1.)),
            T.RandomGrayscale(p=0.2),
            T.ColorJitter(0.4, 0.4, 0.4, 0.4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean, std)])
    else:
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean, std)])
    return transform

def load_data():
    CONFIG.num_classes = 1000
    CONFIG.data_input_size = (3, 224, 224)

    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    train_transform = get_transform(train=True, mean=mean, std=std)
    test_transform = get_transform(train=False, mean=mean, std=std)

    train_folder = os.path.join(CONFIG.dataset_args['root'], 'train')
    train_dataset = datasets.ImageFolder(train_folder, transform=train_transform)
    test_folder = os.path.join(CONFIG.dataset_args['root'], 'val')
    test_dataset = datasets.ImageFolder(test_folder, transform=test_transform)

    train_loader = SeededDataLoader(
        train_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    test_loader = SeededDataLoader(
        test_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=False,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    return {'train': train_loader, 'test': test_loader}
```

datasets/TinyImageNet/dataset.py (+105 lines)

```python
import glob
import os
from shutil import move
from os import rmdir
from datasets.utils import SeededDataLoader
import torchvision.transforms as T

from torchvision import datasets

from globals import CONFIG

# Based on https://github.com/tjmoon0104/pytorch-tiny-imagenet/blob/master/val_format.py
def TINYIMAGENET(root, train=True, transform=None, target_transform=None, download=False):

    def _exists(root, filename):
        return os.path.exists(os.path.join(root, filename))

    def _download(url, root, filename):
        datasets.utils.download_and_extract_archive(url=url,
                                                    download_root=root,
                                                    extract_root=root,
                                                    filename=filename)

    def _setup(root, base_folder):
        target_folder = os.path.join(root, base_folder, 'val/')

        # Map each validation image filename to its class id
        val_dict = {}
        with open(target_folder + 'val_annotations.txt', 'r') as f:
            for line in f.readlines():
                split_line = line.split('\t')
                val_dict[split_line[0]] = split_line[1]

        # Create one subfolder per class
        paths = glob.glob(target_folder + 'images/*')
        for path in paths:
            file = path.split('/')[-1]
            folder = val_dict[file]
            if not os.path.exists(target_folder + str(folder)):
                os.mkdir(target_folder + str(folder))

        # Move each image into its class subfolder
        for path in paths:
            file = path.split('/')[-1]
            folder = val_dict[file]
            dest = target_folder + str(folder) + '/' + str(file)
            move(path, dest)

        os.remove(target_folder + 'val_annotations.txt')
        rmdir(target_folder + 'images')

    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    filename = "tiny-imagenet-200.zip"
    base_folder = 'tiny-imagenet-200'

    if download and not _exists(root, filename):
        _download(url, root, filename)
        _setup(root, base_folder)
    folder = os.path.join(root, base_folder, 'train' if train else 'val')

    return datasets.ImageFolder(folder, transform=transform, target_transform=target_transform)

def get_transform(size, mean, std, preprocess):
    transform = []
    if preprocess:
        transform.append(T.RandomResizedCrop(size=size, scale=(0.1, 1.0), ratio=(0.8, 1.25)))
        transform.append(T.RandomHorizontalFlip())
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean, std))
    return T.Compose(transform)

def load_data():
    size = 64
    if 'pretrain' in CONFIG.experiment_args:
        size = 224

    CONFIG.num_classes = 200
    CONFIG.data_input_size = (3, size, size)

    mean, std = (0.480, 0.448, 0.397), (0.276, 0.269, 0.282)
    train_transform = get_transform(size=size, mean=mean, std=std, preprocess=True)
    test_transform = get_transform(size=size, mean=mean, std=std, preprocess=False)
    train_dataset = TINYIMAGENET(CONFIG.dataset_args['root'], train=True, download=True, transform=train_transform)
    test_dataset = TINYIMAGENET(CONFIG.dataset_args['root'], train=False, download=True, transform=test_transform)

    train_loader = SeededDataLoader(
        train_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    test_loader = SeededDataLoader(
        test_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=False,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    return {'train': train_loader, 'test': test_loader}
```
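The `_setup` step of `TINYIMAGENET` rewrites the flat `val/images/` folder into per-class subfolders using `val_annotations.txt`, so `ImageFolder` can read it. A minimal, self-contained sketch of that same annotation-driven move, runnable against a throwaway directory (the helper name `restructure_val` and the demo filenames are ours):

```python
import os
import tempfile
from shutil import move

def restructure_val(target_folder):
    # val_annotations.txt lines look like: "<filename>\t<class id>\t..."
    val_dict = {}
    with open(os.path.join(target_folder, 'val_annotations.txt')) as f:
        for line in f:
            name, cls = line.split('\t')[:2]
            val_dict[name] = cls
    # Move each image from images/ into its per-class subfolder
    for name, cls in val_dict.items():
        cls_dir = os.path.join(target_folder, cls)
        os.makedirs(cls_dir, exist_ok=True)
        move(os.path.join(target_folder, 'images', name),
             os.path.join(cls_dir, name))
    return val_dict

# Tiny demo on a temporary directory with two fake images
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, 'images'))
for fname in ('a.JPEG', 'b.JPEG'):
    open(os.path.join(root, 'images', fname), 'w').close()
with open(os.path.join(root, 'val_annotations.txt'), 'w') as f:
    f.write('a.JPEG\tn01\t0\t0\t0\t0\n'
            'b.JPEG\tn02\t0\t0\t0\t0\n')
print(restructure_val(root))  # prints {'a.JPEG': 'n01', 'b.JPEG': 'n02'}
```

Unlike the in-repo `_setup`, this sketch leaves `val_annotations.txt` and the empty `images/` folder in place; the real code deletes both once the move completes.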
