forked from marian42/butterflies
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_autoencoder.py
36 lines (25 loc) · 1009 Bytes
/
test_autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import utils
from tqdm import tqdm

from autoencoder import Autoencoder
from image_loader import ImageDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
AUTOENCODER_FILENAME = 'trained_models/autoencoder.to'
OUTPUT_DIRECTORY = 'data/test'


def main():
    """Run every dataset image through the trained autoencoder and save comparisons.

    For each image yielded by ``ImageDataset(return_hashes=True)``, writes
    ``data/test/<hash>.jpg`` containing the input image on the left and the
    autoencoder's reconstruction (``decode(encode(image))``) on the right.
    """
    dataset = ImageDataset(return_hashes=True)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

    autoencoder = Autoencoder()
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    autoencoder.load_state_dict(torch.load(AUTOENCODER_FILENAME, map_location=device))
    # Inputs are moved to `device` below, so the model's parameters must be there too.
    autoencoder.to(device)
    autoencoder.eval()

    # save_image does not create missing directories.
    os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)

    with torch.no_grad():
        for image, hashes in tqdm(data_loader):
            # batch_size=1: the collated hashes are a one-element sequence.
            image_hash = hashes[0]

            # The DataLoader already added the leading batch dimension, so the
            # batch is passed to the model as-is (no extra unsqueeze).
            output = autoencoder.decode(autoencoder.encode(image.to(device)))

            # Build a canvas twice as wide as the input: original | reconstruction.
            original = image[0].cpu()
            channels, height, width = original.shape
            result = torch.zeros((channels, height, 2 * width))
            result[:, :, :width] = original
            # NOTE(review): assumes decode() returns the input's shape, possibly
            # with a leading size-1 batch dim (which the slice assignment drops).
            result[:, :, width:] = output.cpu()

            utils.save_image(result, os.path.join(OUTPUT_DIRECTORY, '{:s}.jpg'.format(image_hash)))


# The guard is required because DataLoader worker processes (num_workers > 0)
# re-import this module under the "spawn" start method (Windows, macOS).
if __name__ == '__main__':
    main()