-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmain-pixelsnail.py
107 lines (86 loc) · 4.57 KB
/
main-pixelsnail.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
import argparse
import datetime
import time
from tqdm import tqdm
from pathlib import Path
from trainer import PixelTrainer
from hps import HPS_PIXEL, HPS_VQVAE
from helper import get_device, get_parameter_count
from datasets import LatentDataset
def _parse_args(argv=None):
    """Parse the command line for PixelSNAIL training.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace`` (also handed to ``PixelTrainer``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_path', type=str)
    parser.add_argument('vqvae_path', type=str)
    parser.add_argument('level', type=int)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--task', type=str, default='cifar10')
    parser.add_argument('--load-path', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=None)
    parser.add_argument('--save-jpg', action='store_true')
    parser.add_argument('--no-tqdm', action='store_true')
    parser.add_argument('--no-save', action='store_true')
    parser.add_argument('--no-amp', action='store_true')  # TODO: Not implemented
    parser.add_argument('--evaluate', action='store_true')  # TODO: Not implemented
    return parser.parse_args(argv)


def _make_run_dirs(level, task, save_id):
    """Create ``runs/pixelsnail-<level>-<task>-<save_id>/`` and its subdirectories.

    Returns:
        ``(checkpoint_dir, image_dir)``. A ``logs`` subdirectory is created as
        well but is not referenced anywhere else in this script.
    """
    root_dir = Path("runs") / f"pixelsnail-{level}-{task}-{save_id}"
    chk_dir = root_dir / "checkpoints"
    img_dir = root_dir / "images"
    log_dir = root_dir / "logs"
    for directory in (chk_dir, img_dir, log_dir):
        # parents=True also creates runs/ and the run root on first use.
        directory.mkdir(parents=True, exist_ok=True)
    return chk_dir, img_dir


def _is_checkpoint_epoch(eid, frequency, max_epochs):
    """Return True if a checkpoint should be written after 0-based epoch *eid*.

    Checkpoints are written every *frequency* epochs (starting with epoch 0)
    and always after the final epoch, so the end of training is never lost.
    """
    return eid % frequency == 0 or eid == max_epochs - 1


def _train_epoch(trainer, loader, level, disable_progress):
    """Run ``trainer.train`` on every batch of *loader*.

    Each batch ``d`` is split by latent level: ``d[level]`` is passed as ``x``
    and ``d[level + 1:]`` as the conditioning ``c``.

    Returns:
        ``(summed_loss, summed_accuracy)`` over all batches.
    """
    total_loss, total_accuracy = 0.0, 0.0
    pb = tqdm(loader, disable=disable_progress)
    for i, d in enumerate(pb):
        x, c = d[level], d[level + 1:]
        loss, accuracy, _ = trainer.train(x, c)
        total_loss += loss
        total_accuracy += accuracy
        pb.set_description(
            f"training loss: {total_loss / (i+1)} | accuracy: {100.0 * total_accuracy / (i+1)}%"
        )
    return total_loss, total_accuracy


def _evaluate_epoch(trainer, loader, level, disable_progress, nb_entries, image_path=None):
    """Run ``trainer.eval`` on every batch of *loader*.

    If *image_path* is given, the first batch's predictions (argmax over dim 1
    of the returned ``y``) and its input codes ``x`` are saved as a two-row
    image grid, both scaled by ``1 / nb_entries``.

    Returns:
        ``(summed_loss, summed_accuracy)`` over all batches.
    """
    total_loss, total_accuracy = 0.0, 0.0
    pb = tqdm(loader, disable=disable_progress)
    for i, d in enumerate(pb):
        x, c = d[level], d[level + 1:]
        loss, accuracy, y = trainer.eval(x, c)
        total_loss += loss
        total_accuracy += accuracy
        pb.set_description(
            f"evaluation loss: {total_loss / (i+1)} | accuracy: {100.0 * total_accuracy / (i+1)}%"
        )
        if i == 0 and image_path is not None:
            # y is presumably per-position scores over nb_entries codes (dim 1) — TODO confirm.
            pred = y.argmax(dim=1).detach().cpu() / nb_entries
            target = x.cpu() / nb_entries
            grid = torch.cat([pred, target], dim=0).unsqueeze(1)
            # Use the actual batch size so predictions and targets stay on
            # separate rows even when the first batch is shorter than configured.
            save_image(grid, image_path, nrow=x.shape[0])
    return total_loss, total_accuracy


def main(argv=None):
    """Train a PixelSNAIL model on one latent level of a ``LatentDataset``.

    Each epoch trains on ``<dataset_path>/train``, evaluates on
    ``<dataset_path>/test``, and (unless ``--no-save``) periodically writes a
    reconstruction image and a checkpoint under ``runs/``.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    cfg_pixel = HPS_PIXEL[args.task]
    cfg_vqvae = HPS_VQVAE[args.task]
    dataset_path = Path(args.dataset_path)
    save_id = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if args.batch_size:
        cfg_pixel.mini_batch_size = args.batch_size

    # Output directories only exist when saving is enabled; every use below
    # is guarded by `not args.no_save`.
    chk_dir = img_dir = None
    if not args.no_save:
        chk_dir, img_dir = _make_run_dirs(args.level, args.task, save_id)

    print("> Loading Latent dataset")
    train_dataset = LatentDataset(dataset_path / "train")
    test_dataset = LatentDataset(dataset_path / "test")
    cfg_pixel.code_shape = train_dataset.get_shape(args.level)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg_pixel.mini_batch_size,
        num_workers=cfg_pixel.nb_workers,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=cfg_pixel.mini_batch_size,
        num_workers=cfg_pixel.nb_workers,
        shuffle=False,
    )

    print("> Initialising Model")
    trainer = PixelTrainer(cfg_pixel, cfg_vqvae, args)
    if args.load_path:
        print("> Loading model parameters from checkpoint")
        trainer.load_checkpoint(args.load_path)

    img_ext = 'jpg' if args.save_jpg else 'png'
    # Guard against empty loaders so the epoch summaries never divide by zero.
    nb_train_batches = max(len(train_loader), 1)
    nb_test_batches = max(len(test_loader), 1)

    for eid in range(cfg_pixel.max_epochs):
        print(f"> Epoch {eid+1}/{cfg_pixel.max_epochs}:")
        epoch_start_time = time.time()

        epoch_loss, epoch_accuracy = _train_epoch(trainer, train_loader, args.level, args.no_tqdm)
        print(
            f"> Training loss: {epoch_loss / nb_train_batches} | "
            f"accuracy: {100.0 * epoch_accuracy / nb_train_batches}%"
        )

        image_path = None
        if not args.no_save and eid % cfg_pixel.image_frequency == 0:
            image_path = img_dir / f"recon-{str(eid).zfill(4)}.{img_ext}"
        epoch_loss, epoch_accuracy = _evaluate_epoch(
            trainer, test_loader, args.level, args.no_tqdm, cfg_pixel.nb_entries, image_path
        )
        print(
            f"> Evaluation loss: {epoch_loss / nb_test_batches} | "
            f"accuracy: {100.0 * epoch_accuracy / nb_test_batches}%"
        )

        if not args.no_save and _is_checkpoint_epoch(
            eid, cfg_pixel.checkpoint_frequency, cfg_pixel.max_epochs
        ):
            trainer.save_checkpoint(
                chk_dir / f"pixelsnail-{args.level}-{args.task}-state-dict-{str(eid).zfill(4)}.pt"
            )
        print(f"> Epoch time taken: {time.time() - epoch_start_time:.2f} seconds.")
        print()


if __name__ == '__main__':
    main()