@@ -92,21 +92,36 @@ classification task of your choice.)
    )
    return model

- def get_dataloader(batch_size=256, num_workers=8, split='train'):
-
-     transforms = torchvision.transforms.Compose(
-         [torchvision.transforms.RandomHorizontalFlip(),
-          torchvision.transforms.RandomAffine(0),
-          torchvision.transforms.ToTensor(),
-          torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201))])
-
+ def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
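+     # (added comment) augment=False applies only ToTensor + Normalize, giving
+     # deterministic inputs, e.g. for evaluation or TRAK featurizing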
+     if augment:
+         transforms = torchvision.transforms.Compose(
+             [torchvision.transforms.RandomHorizontalFlip(),
+              torchvision.transforms.RandomAffine(0),
+              torchvision.transforms.ToTensor(),
+              torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                               (0.2023, 0.1994, 0.201))])
+     else:
+         transforms = torchvision.transforms.Compose([
+             torchvision.transforms.ToTensor(),
+             torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                              (0.2023, 0.1994, 0.201))])
+
    is_train = (split == 'train')
-     dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/', download=True, train=is_train, transform=transforms)
-     loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)
-
+     dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
+                                            download=True,
+                                            train=is_train,
+                                            transform=transforms)
+
+     loader = torch.utils.data.DataLoader(dataset=dataset,
+                                          shuffle=shuffle,
+                                          batch_size=batch_size,
+                                          num_workers=num_workers)
+
    return loader

- def train(model, loader, lr=0.4, epochs=24, momentum=0.9, weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0):
+ def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
+           weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):
+
    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with single triangle
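The diff elides the scheduler construction that follows this comment. One standard way to realize a single-triangle cyclic schedule with these hyperparameters (a sketch, assuming :code:`np`, :code:`lr_scheduler`, and :code:`GradScaler` are imported as elsewhere in this quickstart) is:

.. code-block:: python

    # triangular LR multiplier 0 -> 1 -> 0, peaking at lr_peak_epoch;
    # scheduler.step() is called once per iteration below
    lr_schedule = np.interp(np.arange((epochs + 1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()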
@@ -118,9 +133,8 @@ classification task of your choice.)
    loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

    for ep in range(epochs):
-         model_count = 0
        for it, (ims, labs) in enumerate(loader):
-             ims = ims.float().cuda()
+             ims = ims.cuda()
            labs = labs.cuda()
            opt.zero_grad(set_to_none=True)
            with autocast():
@@ -131,15 +145,19 @@ classification task of your choice.)
            scaler.step(opt)
            scaler.update()
            scheduler.step()
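+         # (added comment) snapshot the model at several epochs; TRAK ensembles its
+         # attribution estimates over multiple checkpoints/models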
+         if ep in [12, 15, 18, 21, 23]:
+             torch.save(model.state_dict(), f'./checkpoints/sd_{model_id}_epoch_{ep}.pt')
+
+     return model

os.makedirs('./checkpoints', exist_ok=True)
+ loader_for_training = get_dataloader(batch_size=512, split='train', shuffle=True)

- for i in tqdm(range(3), desc='Training models..'):
+ # you can modify the for loop below to train more models
+ for i in tqdm(range(1), desc='Training models..'):
    model = construct_rn9().to(memory_format=torch.channels_last).cuda()
-     loader_train = get_dataloader(batch_size=512, split='train')
-     train(model, loader_train)
+     model = train(model, loader_for_training, model_id=i)

-     torch.save(model.state_dict(), f'./checkpoints/sd_{i}.pt')

.. raw:: html

@@ -311,4 +329,4 @@ The final line above returns :code:`TRAK` scores as a :code:`numpy.array` from t

That's it!
Once you have your model(s) and your data, just a few API calls to TRAK
- let's you compute data attribution scores.
+ let you compute data attribution scores.
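For orientation, those calls boil down to the pattern below (a condensed sketch of the :code:`TRAKer` workflow used on this page; :code:`ckpt_files`, :code:`loader_targets`, and :code:`num_targets` are placeholder names, and exact signatures may differ across TRAK versions):

.. code-block:: python

    from trak import TRAKer

    traker = TRAKer(model=model,
                    task='image_classification',
                    train_set_size=len(loader_for_training.dataset))

    # featurize the training data once per saved checkpoint
    for model_id, ckpt_file in enumerate(ckpt_files):
        traker.load_checkpoint(torch.load(ckpt_file), model_id=model_id)
        for batch in loader_for_training:
            batch = [x.cuda() for x in batch]
            traker.featurize(batch=batch, num_samples=batch[0].shape[0])
    traker.finalize_features()

    # score target examples against the training set
    for model_id, ckpt_file in enumerate(ckpt_files):
        traker.start_scoring_checkpoint(exp_name='quickstart',
                                        checkpoint=torch.load(ckpt_file),
                                        model_id=model_id,
                                        num_targets=num_targets)
        for batch in loader_targets:
            batch = [x.cuda() for x in batch]
            traker.score(batch=batch, num_samples=batch[0].shape[0])

    scores = traker.finalize_scores(exp_name='quickstart')  # numpy.array of scores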