# utils.py - logging, EMA, checkpointing, sampling and FID/IS evaluation helpers.
from collections import OrderedDict
import logging
import torch
import torch.distributed as dist
from PIL import Image
from pytorch_gan_metrics import get_inception_score_and_fid
def create_logger(logging_dir=None):
    """
    Create a logger that writes to stdout and, optionally, to a log file.

    Only the rank-0 process configures real handlers; every other rank gets a
    dummy logger with a ``NullHandler`` attached. When ``torch.distributed`` is
    unavailable, or no process group has been initialized, the process is
    treated as rank 0.

    Args:
        logging_dir: Directory that receives ``log.txt``. It is created when
            missing. When ``None``, only the stream handler is installed.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    import os  # local import so the block does not depend on file-level imports

    # dist.get_rank() needs an initialized default process group; fall back to
    # rank 0 so the helper also works in plain single-process runs.
    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0

    logger = logging.getLogger(__name__)
    if rank == 0:  # real logger
        handlers = [logging.StreamHandler()]
        if logging_dir is not None:
            # Without this guard the default argument would open "None/log.txt".
            os.makedirs(logging_dir, exist_ok=True)
            handlers.append(logging.FileHandler(f"{logging_dir}/log.txt"))
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=handlers,
        )
    elif not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
        # dummy logger; attach the NullHandler only once across repeated calls
        logger.addHandler(logging.NullHandler())
    return logger
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every EMA parameter one step towards the matching online parameter.

    Each parameter of ``ema_model`` is updated in place to
    ``decay * ema + (1 - decay) * param``, where ``param`` is the parameter of
    ``model`` that has the same name.

    Args:
        ema_model: Module holding the averaged weights (modified in place).
        model: Module holding the current weights (left untouched).
        decay: Weight kept from the previous EMA value.
    """
    averaged = dict(ema_model.named_parameters())
    step = 1 - decay
    # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
    for key, source in model.named_parameters():
        target = averaged[key]
        target.mul_(decay)
        target.add_(source.data, alpha=step)
def requires_grad(model, flag=True):
    """
    Switch gradient tracking on or off for a whole model.

    Args:
        model: Object whose ``parameters()`` are updated.
        flag: Value assigned to ``requires_grad`` on each parameter.
    """
    for weight in model.parameters():
        weight.requires_grad = flag
def save_ckpt(args, model, ema, opt, checkpoint_path):
    """
    Save a checkpoint containing the online model, EMA, and optimizer states.

    Args:
        args: Run configuration, stored unchanged under the ``'args'`` key.
        model: Online model. If it exposes a ``module`` attribute (presumably a
            DDP-style wrapper), the state dict of that inner module is saved;
            otherwise the model's own state dict is saved.
        ema: EMA model; its ``state_dict()`` is stored under ``'ema'``.
        opt: Optimizer; its ``state_dict()`` is stored under ``'opt'``.
        checkpoint_path: Destination handed to ``torch.save``.
    """
    # Accept both wrapped and bare models instead of failing on a missing `.module`.
    online = getattr(model, 'module', model)
    checkpoint = {
        'args': args,
        'model': online.state_dict(),
        'ema': ema.state_dict(),
        'opt': opt.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)
def sample_image(args, model, device, image_path, set_train=False, cond=False):
    """
    Render a 16 x 16 grid of generated samples and save it as one RGB image.

    Args:
        args: Provides ``input_size`` and, when ``cond`` is set, ``num_classes``.
        model: Called as ``model(noise, labels)``; put into eval mode first.
        device: Device the noise and labels are moved to.
        image_path: Where the grid image is written.
        set_train: Put the model back into train mode afterwards. Use this when
            sampling from the online model.
        cond: Draw random class labels; otherwise ``None`` is passed as labels.
    """
    model.eval()

    grid = 16
    side = args.input_size
    count = grid * grid

    noise = torch.randn(count, 3, side, side).to(device)
    if cond:
        labels = torch.randint(0, args.num_classes, (count,)).to(device)
    else:
        labels = None

    with torch.no_grad():
        frames = model(noise, labels)
        frames = frames.view(grid, grid, 3, side, side)

    # Scale to 8-bit pixel values.
    frames = frames * 127.5 + 128
    frames = frames.clip(0, 255).to(torch.uint8)

    # (row, col, ch, h, w) -> (row, h, col, w, ch), then flatten into one canvas.
    canvas = frames.permute(0, 3, 1, 4, 2)
    canvas = canvas.reshape(grid * side, grid * side, 3).cpu().numpy()
    Image.fromarray(canvas, 'RGB').save(image_path)

    # Drop the references before asking CUDA to release cached blocks.
    del canvas, frames, noise, labels
    torch.cuda.empty_cache()

    if set_train:
        model.train()
def num_to_groups(num, divisor):
    """
    Split ``num`` samples into batch sizes of at most ``divisor``.

    Returns a list with ``num // divisor`` entries equal to ``divisor``,
    followed by the remainder when it is positive,
    e.g. ``num_to_groups(10, 3) == [3, 3, 3, 1]``.
    """
    full_batches, leftover = divmod(num, divisor)
    sizes = [divisor] * full_batches
    return sizes + [leftover] if leftover > 0 else sizes
def sample_fid(args, model, device, rank, set_train=False, cond=False):
    """
    Generate this process's share of ``args.eval_samples`` images for FID/IS.

    The total is divided evenly over ``dist.get_world_size()`` processes and
    sampled in batches of at most ``args.eval_batch_size``.

    Args:
        args: Provides ``eval_samples``, ``eval_batch_size``, ``num_classes``
            and ``input_size``.
        model: Called as ``model(noise, labels)``; put into eval mode and
            moved to ``device``.
        device: Device used for sampling.
        rank: Not used by this function.
        set_train: Put the model back into train mode afterwards. Use this when
            sampling from the online model.
        cond: Draw random class labels; otherwise ``None`` is passed as labels.

    Returns:
        A tensor with all generated batches concatenated along dim 0.
    """
    # Work out the batch sizes handled by this process.
    world_size = dist.get_world_size()
    assert args.eval_samples % world_size == 0
    batch_sizes = num_to_groups(args.eval_samples // world_size, args.eval_batch_size)

    # Distributed EMA/online evaluation: the DDP wrapper is not needed here
    # because no gradient sync takes place.
    model.eval()
    model = model.to(device)

    num_classes = args.num_classes
    side = args.input_size

    def generate(count):
        noise = torch.randn(count, 3, side, side).to(device)
        labels = torch.randint(0, num_classes, (count,)).to(device) if cond else None
        return model(noise, labels)

    with torch.no_grad():
        samples = torch.cat([generate(count) for count in batch_sizes], dim=0)

    torch.cuda.empty_cache()
    if set_train:
        model.train()
    return samples
def compute_fid_is(args, all_images, rank):
    """
    Score generated images with FID and Inception Score.

    Args:
        args: Provides ``stat_path``, forwarded to ``get_inception_score_and_fid``.
        all_images: Sequence of image tensors, concatenated along dim 0.
            Presumably model outputs in roughly [-1, 1] - confirm with caller.
        rank: Not used by this function.

    Returns:
        ``(FID, IS)``; the Inception Score standard deviation is discarded.
    """
    pixels = torch.cat(all_images, dim=0)
    # Quantize to 8-bit pixel values, then rescale to floats in [0, 1] on the CPU.
    pixels = (pixels * 127.5 + 128).clip(0, 255).to(torch.uint8)
    pixels = pixels.float().div(255).cpu()

    inception, fid = get_inception_score_and_fid(pixels, args.stat_path)
    inception_score, _inception_std = inception

    torch.cuda.empty_cache()
    return fid, inception_score