# train.py
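"""Training script for MipNeRF.

Builds the train (and optional test) dataloaders, constructs a MipNeRF model,
and optimizes it with AdamW under the MipLRDecay schedule, logging loss and
PSNR to TensorBoard and periodically checkpointing the model and optimizer.
"""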
import os.path
import shutil
from os import path

import numpy as np
import torch
import torch.optim as optim
import torch.utils.tensorboard as tb
from tqdm import tqdm

from config import get_config
from datasets import get_dataloader, cycle
from loss import NeRFLoss, mse_to_psnr
from model import MipNeRF
from scheduler import MipLRDecay


def train_model(config):
    model_save_path = path.join(config.log_dir, "model.pt")
    optimizer_save_path = path.join(config.log_dir, "optim.pt")

    # Infinite iterators over batches of (rays, pixels) for training and, optionally, evaluation.
    data = iter(cycle(get_dataloader(dataset_name=config.dataset_name, base_dir=config.base_dir, split="train", factor=config.factor, batch_size=config.batch_size, shuffle=True, device=config.device)))
    eval_data = None
    if config.do_eval:
        eval_data = iter(cycle(get_dataloader(dataset_name=config.dataset_name, base_dir=config.base_dir, split="test", factor=config.factor, batch_size=config.batch_size, shuffle=True, device=config.device)))

    model = MipNeRF(
        use_viewdirs=config.use_viewdirs,
        randomized=config.randomized,
        ray_shape=config.ray_shape,
        white_bkgd=config.white_bkgd,
        num_levels=config.num_levels,
        num_samples=config.num_samples,
        hidden=config.hidden,
        density_noise=config.density_noise,
        density_bias=config.density_bias,
        rgb_padding=config.rgb_padding,
        resample_padding=config.resample_padding,
        min_deg=config.min_deg,
        max_deg=config.max_deg,
        viewdirs_min_deg=config.viewdirs_min_deg,
        viewdirs_max_deg=config.viewdirs_max_deg,
        device=config.device,
    )
    optimizer = optim.AdamW(model.parameters(), lr=config.lr_init, weight_decay=config.weight_decay)
    if config.continue_training:
        model.load_state_dict(torch.load(model_save_path))
        optimizer.load_state_dict(torch.load(optimizer_save_path))
    scheduler = MipLRDecay(optimizer, lr_init=config.lr_init, lr_final=config.lr_final, max_steps=config.max_steps, lr_delay_steps=config.lr_delay_steps, lr_delay_mult=config.lr_delay_mult)
    loss_func = NeRFLoss(config.coarse_weight_decay)
    model.train()

    # Start a fresh TensorBoard run for this training session.
    os.makedirs(config.log_dir, exist_ok=True)
    shutil.rmtree(path.join(config.log_dir, 'train'), ignore_errors=True)
    logger = tb.SummaryWriter(path.join(config.log_dir, 'train'), flush_secs=1)
    for step in tqdm(range(0, config.max_steps)):
        rays, pixels = next(data)
        comp_rgb, _, _ = model(rays)
        pixels = pixels.to(config.device)

        # Compute loss and update model weights.
        loss_val, psnr = loss_func(comp_rgb, pixels, rays.lossmult.to(config.device))
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
        scheduler.step()

        # Log training metrics; psnr holds one entry per level (coarse levels first, fine level last).
        psnr = psnr.detach().cpu().numpy()
        logger.add_scalar('train/loss', float(loss_val.detach().cpu().numpy()), global_step=step)
        logger.add_scalar('train/coarse_psnr', float(np.mean(psnr[:-1])), global_step=step)
        logger.add_scalar('train/fine_psnr', float(psnr[-1]), global_step=step)
        logger.add_scalar('train/avg_psnr', float(np.mean(psnr)), global_step=step)
        logger.add_scalar('train/lr', float(scheduler.get_last_lr()[-1]), global_step=step)

        # Periodically evaluate and checkpoint the model and optimizer.
        if step % config.save_every == 0:
            if eval_data:
                del rays
                del pixels
                psnr = eval_model(config, model, eval_data)
                psnr = psnr.detach().cpu().numpy()
                logger.add_scalar('eval/coarse_psnr', float(np.mean(psnr[:-1])), global_step=step)
                logger.add_scalar('eval/fine_psnr', float(psnr[-1]), global_step=step)
                logger.add_scalar('eval/avg_psnr', float(np.mean(psnr)), global_step=step)
            torch.save(model.state_dict(), model_save_path)
            torch.save(optimizer.state_dict(), optimizer_save_path)

    # Save the final model and optimizer state after training completes.
    torch.save(model.state_dict(), model_save_path)
    torch.save(optimizer.state_dict(), optimizer_save_path)


def eval_model(config, model, data):
    # Compute PSNR for each output level (coarse levels first, fine level last) on a single test batch.
    model.eval()
    rays, pixels = next(data)
    with torch.no_grad():
        comp_rgb, _, _ = model(rays)
    pixels = pixels.to(config.device)
    model.train()
    return torch.tensor([mse_to_psnr(torch.mean((rgb - pixels[..., :3])**2)) for rgb in comp_rgb])


if __name__ == "__main__":
    config = get_config()
    train_model(config)