#!/usr/bin/env python
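"""Train and evaluate a CNN on MNIST, with optional multi-process
distributed training via torch.distributed.

All data, hyperparameter, checkpoint, and distributed settings are read
from the file passed with --config (parsed by NetworkConfig); see the
argument list at the bottom of this file for the available flags.
"""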
import os
import sys
import argparse
from argparse import RawTextHelpFormatter
import time
import matplotlib.pyplot as plt
import numpy as np
from functools import reduce
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.distributed as dist
from torch.backends import cudnn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from dataloader import *
from models import *
from utils import *
best_prec1 = 0
configure("logs/log-1")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    config = NetworkConfig(args.config)

    args.distributed = config.distributed['world_size'] > 1
    if args.distributed:
        print('[+] Distributed backend')
        dist.init_process_group(backend=config.distributed['dist_backend'],
                                init_method=config.distributed['dist_url'],
                                world_size=config.distributed['world_size'])

    # creating model instance
    model = Model(config)

    # plotting interactively
    plt.ion()

    if args.distributed:
        model.to(device)
        model = nn.parallel.DistributedDataParallel(model)
    elif args.gpu:
        model = nn.DataParallel(model).to(device)
    else:
        # CPU-only runs are not supported; bail out early
        return
    # Data Loading
    train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                              train=False,
                                              transform=transforms.ToTensor())

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # shuffle and sampler are mutually exclusive in DataLoader, so only
    # shuffle when no DistributedSampler is in use
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.data['batch_size'],
        shuffle=(train_sampler is None and config.data['shuffle']),
        num_workers=config.data['workers'], pin_memory=config.data['pin_memory'],
        sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=config.data['batch_size'], shuffle=config.data['shuffle'],
        num_workers=config.data['workers'], pin_memory=config.data['pin_memory'])
    # Training and Evaluation
    if args.train:
        trainer = Trainer('cnn', config, train_loader, model)
    if args.evaluate:
        evaluator = Evaluator('cnn', config, val_loader, model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), config.hyperparameters['lr'])
    # weight_decay=config.hyperparameters['weight_decay'])

    if args.train:
        trainer.setCriterion(criterion)
        trainer.setOptimizer(optimizer)
    if args.evaluate:
        evaluator.setCriterion(criterion)
    # optionally resume from a checkpoint
    if args.resume and not args.train:
        path = os.path.join(config.checkpoints['loc'],
                            config.checkpoints['ckpt_fname'])
        checkpoint = torch.load(path)
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("[#] Loaded Checkpoint '{}' (epoch {})"
              .format(config.checkpoints['ckpt_fname'], checkpoint['epoch']))
    elif args.resume:
        # the Trainer restores its own state when resuming mid-training
        trainer.load_saved_checkpoint(checkpoint=None)
    # Turn on cudnn benchmarking if the input sizes don't vary;
    # it finds the fastest way to run the model on this machine
    cudnn.benchmark = True

    best_prec1 = 0
    for epoch in range(config.hyperparameters['num_epochs']):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if args.train:
            trainer.adjust_learning_rate(epoch)
            trainer.train(epoch)
        if args.evaluate:
            prec1 = evaluator.evaluate(epoch)

        if args.train and args.evaluate:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint=None)
        elif args.train:
            # no validation signal, so mark every checkpoint as the best so far
            is_best = True
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint=None)
        elif args.evaluate:
            # evaluation-only runs need just a single pass
            return
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Disentangling Variations',
                                     formatter_class=RawTextHelpFormatter)
    parser.add_argument('--gpu', type=int, default=0,
                        help="Turn ON for GPU support; default=0")
    parser.add_argument('--train', type=int, default=1,
                        help="Turn ON to train the model; default=1")
    parser.add_argument('--resume', type=int, default=0,
                        help="Turn ON to resume training from the latest checkpoint; default=0")
    parser.add_argument('--checkpoints', type=str, default="./checkpoints",
                        help="Directory that contains the checkpoints")
    parser.add_argument('--config', type=str, required=True,
                        help="File to load the model's configuration from")
    parser.add_argument('--seed', type=int, default=100,
                        help="Seed for the random number generators; default=100")
    parser.add_argument('--pretrained', type=int, default=0,
                        help="Turn ON if checkpoints of the model are available in the /checkpoints dir")
    parser.add_argument('--evaluate', type=int, default=0,
                        help="Turn ON to evaluate the model on the validation set; default=0")
    args = parser.parse_args()
    main(args)
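
# Example invocation (the config filename is hypothetical; any file that
# NetworkConfig can parse will do):
#   python main.py --config config.yaml --gpu 1 --train 1 --evaluate 1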