Skip to content

Commit 2c2d022

Browse files
0.1.3 (#32)
* allow skipping model IDs in finalize scores * allow subclassing of saver and score_computer directly from traker args * default to BasicProjector if CudaProjector projeciton step errors out * add another type of error that sometime occurs when fast_jl has issues * update quickstart notebook * Add link to colab with pre-computed trak scores to readme * add dropbox links to quickstart nb * update training code in quickstart tutorial * bump version --------- Co-authored-by: Joshua Vendrow
1 parent 8995c3e commit 2c2d022

8 files changed

+395
-431
lines changed

README.md

+17-19
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
[![PyPI version](https://badge.fury.io/py/traker.svg)](https://badge.fury.io/py/traker)
21
[![arXiv](https://img.shields.io/badge/arXiv-2303.14186-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2303.14186)
2+
[![PyPI version](https://badge.fury.io/py/traker.svg)](https://badge.fury.io/py/traker)
3+
4+
# TRAK: Attributing Model Behavior at Scale
35

46
[[docs & tutorials]](https://trak.readthedocs.io/en/latest/)
5-
[[paper]](https://arxiv.org/abs/2303.14186)
67
[[blog post]](https://gradientscience.org/trak/)
78
[[website]](https://trak.csail.mit.edu)
89

9-
# TRAK: Attributing Model Behavior at Scale
10-
1110
In our [paper](https://arxiv.org/abs/2303.14186), we introduce a new data attribution method called `TRAK` (Tracing with the
1211
Randomly-Projected After Kernel). Using `TRAK`, you can make accurate
1312
counterfactual predictions (e.g., answers to questions of the form “what would
@@ -17,21 +16,10 @@ comparably effective methods, e.g., see our evaluation on:
1716

1817
![Main figure](/docs/assets/main_figure.png)
1918

20-
## Citation
21-
If you use this code in your work, please cite using the following BibTeX entry:
22-
```
23-
@inproceedings{park2023trak,
24-
title = {TRAK: Attributing Model Behavior at Scale},
25-
author = {Sung Min Park and Kristian Georgiev and Andrew Ilyas and Guillaume Leclerc and Aleksander Madry},
26-
booktitle = {Arxiv preprint arXiv:2303.14186},
27-
year = {2023}
28-
}
29-
```
30-
3119
## Usage
3220

33-
34-
[[Quickstart]](https://trak.readthedocs.io/en/latest/quickstart.html)
21+
[[quickstart]](https://trak.readthedocs.io/en/latest/quickstart.html)
22+
[[pre-computed TRAK scores for CIFAR-10]](https://colab.research.google.com/drive/1Mlpzno97qpI3UC1jpOATXEHPD-lzn9Wg?usp=sharing)
3523

3624
Check [our docs](https://trak.readthedocs.io/en/latest/) for more detailed examples and
3725
tutorials on how to use `TRAK`. Below, we provide a brief blueprint of using `TRAK`'s API to compute attribution scores.
@@ -74,6 +62,17 @@ scores = traker.finalize_scores()
7462
## Examples
7563
You can find several end-to-end examples in the `examples/` directory.
7664

65+
## Citation
66+
If you use this code in your work, please cite using the following BibTeX entry:
67+
```
68+
@inproceedings{park2023trak,
69+
title = {TRAK: Attributing Model Behavior at Scale},
70+
author = {Sung Min Park and Kristian Georgiev and Andrew Ilyas and Guillaume Leclerc and Aleksander Madry},
71+
booktitle = {Arxiv preprint arXiv:2303.14186},
72+
year = {2023}
73+
}
74+
```
75+
7776
## Installation
7877

7978
To install the version of our package which contains a fast, custom `CUDA`
@@ -93,9 +92,8 @@ pip install traker
9392

9493
Please send an email to trak@mit.edu
9594

96-
## Maintainers:
95+
## Maintainers
9796

9897
[Kristian Georgiev](https://twitter.com/kris_georgiev1)<br>
9998
[Andrew Ilyas](https://twitter.com/andrew_ilyas)<br>
100-
[Guillaume Leclerc](https://twitter.com/gpoleclerc)<br>
10199
[Sung Min Park](https://twitter.com/smsampark)

docs/source/conf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
author = 'Kristian Georgiev'
2323

2424
# The full version, including alpha/beta/rc tags
25-
release = '0.1.2'
26-
version = '0.1.2'
25+
release = '0.1.3'
26+
version = '0.1.3'
2727

2828

2929
# -- General configuration ---------------------------------------------------

docs/source/quickstart.rst

+37-19
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,36 @@ classification task of your choice.)
9292
)
9393
return model
9494
95-
def get_dataloader(batch_size=256, num_workers=8, split='train'):
96-
97-
transforms = torchvision.transforms.Compose(
98-
[torchvision.transforms.RandomHorizontalFlip(),
99-
torchvision.transforms.RandomAffine(0),
100-
torchvision.transforms.ToTensor(),
101-
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201))])
102-
95+
def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
96+
if augment:
97+
transforms = torchvision.transforms.Compose(
98+
[torchvision.transforms.RandomHorizontalFlip(),
99+
torchvision.transforms.RandomAffine(0),
100+
torchvision.transforms.ToTensor(),
101+
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
102+
(0.2023, 0.1994, 0.201))])
103+
else:
104+
transforms = torchvision.transforms.Compose([
105+
torchvision.transforms.ToTensor(),
106+
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
107+
(0.2023, 0.1994, 0.201))])
108+
103109
is_train = (split == 'train')
104-
dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/', download=True, train=is_train, transform=transforms)
105-
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)
106-
110+
dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
111+
download=True,
112+
train=is_train,
113+
transform=transforms)
114+
115+
loader = torch.utils.data.DataLoader(dataset=dataset,
116+
shuffle=shuffle,
117+
batch_size=batch_size,
118+
num_workers=num_workers)
119+
107120
return loader
108121
109-
def train(model, loader, lr=0.4, epochs=24, momentum=0.9, weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0):
122+
def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
123+
weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):
124+
110125
opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
111126
iters_per_epoch = len(loader)
112127
# Cyclic LR with single triangle
@@ -118,9 +133,8 @@ classification task of your choice.)
118133
loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)
119134
120135
for ep in range(epochs):
121-
model_count = 0
122136
for it, (ims, labs) in enumerate(loader):
123-
ims = ims.float().cuda()
137+
ims = ims.cuda()
124138
labs = labs.cuda()
125139
opt.zero_grad(set_to_none=True)
126140
with autocast():
@@ -131,15 +145,19 @@ classification task of your choice.)
131145
scaler.step(opt)
132146
scaler.update()
133147
scheduler.step()
148+
if ep in [12, 15, 18, 21, 23]:
149+
torch.save(model.state_dict(), f'./checkpoints/sd_{model_id}_epoch_{ep}.pt')
150+
151+
return model
134152
135153
os.makedirs('./checkpoints', exist_ok=True)
154+
loader_for_training = get_dataloader(batch_size=512, split='train', shuffle=True)
136155
137-
for i in tqdm(range(3), desc='Training models..'):
156+
# you can modify the for loop below to train more models
157+
for i in tqdm(range(1), desc='Training models..'):
138158
model = construct_rn9().to(memory_format=torch.channels_last).cuda()
139-
loader_train = get_dataloader(batch_size=512, split='train')
140-
train(model, loader_train)
159+
model = train(model, loader_for_training, model_id=i)
141160
142-
torch.save(model.state_dict(), f'./checkpoints/sd_{i}.pt')
143161
144162
.. raw:: html
145163

@@ -311,4 +329,4 @@ The final line above returns :code:`TRAK` scores as a :code:`numpy.array` from t
311329

312330
That's it!
313331
Once you have your model(s) and your data, just a few API-calls to TRAK
314-
let's you compute data attribution scores.
332+
let's you compute data attribution scores.

0 commit comments

Comments
 (0)