-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdemo.py
79 lines (66 loc) · 2.55 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import requests
import os
import gradio as gr
import numpy as np
import torch
import torchvision.models as models
from configs.default import get_cfg_defaults
from modeling.build import build_model
from utils.data_utils import linear_scaling
# Shared Dropbox link to the pretrained CIFR checkpoint. "dl=1" forces Dropbox
# to serve the raw file; "dl=0" returns the HTML preview page instead.
url = "https://www.dropbox.com/s/uxvax5sjx5iysyl/cifr.pth?dl=1"


def _download_checkpoint(src_url, dst_path, chunk_size=1 << 20, timeout=60):
    """Download *src_url* to *dst_path* unless *dst_path* already exists.

    The body is streamed into a temporary ``<dst_path>.part`` file and only
    moved into place once the transfer has completed. An interrupted download
    therefore never leaves a truncated checkpoint at *dst_path*, which a
    later run would mistake for a complete one.

    Args:
        src_url: URL to fetch.
        dst_path: Local destination path.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On network failures or timeouts.
    """
    if os.path.exists(dst_path):
        # Already downloaded; do not touch the network at all.
        return
    tmp_path = dst_path + ".part"
    try:
        with requests.get(src_url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename: dst_path appears only once the file is complete.
        os.replace(tmp_path, dst_path)
    finally:
        # On any failure, drop the partial file. After a successful replace
        # the temporary path no longer exists and this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


_download_checkpoint(url, "cifr.pth")
# Project default configuration, pointed at the locally stored checkpoint file.
cfg = get_cfg_defaults()
cfg.MODEL.CKPT = "cifr.pth"

# build_model returns a pair; only its first element (the restoration network)
# is kept. The network is switched to eval mode for inference.
net, _ = build_model(cfg)
net = net.eval()

# Convolutional trunk ("features") of torchvision's pretrained VGG-16, in eval
# mode. Its output is fed to `net` alongside the image at inference time.
vgg16 = models.vgg16(pretrained=True).features
vgg16 = vgg16.eval()
def load_checkpoints_from_ckpt(ckpt_path):
    """Load the ``"ifr"`` state dict from *ckpt_path* into the global ``net``.

    The checkpoint is mapped onto the CPU, so no GPU is needed to load it.

    Args:
        ckpt_path: Path to a file readable by ``torch.load`` that contains a
            mapping with an ``"ifr"`` entry holding the network's state dict.

    Raises:
        KeyError: If the checkpoint has no ``"ifr"`` entry.
    """
    # NOTE(security): torch.load is pickle-based. The checkpoint used here is
    # downloaded from a remote URL, so it must come from a trusted source.
    # Consider weights_only=True if the installed torch version supports it.
    checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
    ifr_state = checkpoint["ifr"]
    net.load_state_dict(ifr_state)


load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
def filter_removal(img):
    """Run the CIFR network on one image and return the filter-free result.

    Args:
        img: Image as an array-like of shape (H, W, C) or (H, W). Values are
            assumed to be in 0..255 (they are divided by 255 below).
            Grayscale input, either 2-D or with a single channel, is
            replicated to three channels. Channels beyond the third (e.g. an
            alpha channel) are dropped. ``None``, which is what Gradio
            supplies when no image was given, is passed through.

    Returns:
        An (H, W, 3) array with values clamped to [0, 1], or ``None`` when
        *img* is ``None``.
    """
    if img is None:
        return None

    img = np.asarray(img)
    if img.ndim == 2:
        # Add a trailing channel axis so grayscale takes the same path below.
        img = img[:, :, np.newaxis]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    # The network and the VGG trunk both take 3-channel input.
    img = img[:, :, :3]

    # HWC -> 1xCxHxW float tensor scaled to [0, 1], then to the model's range.
    arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    arr = torch.tensor(arr, dtype=torch.float32) / 255.0
    arr = linear_scaling(arr)

    with torch.no_grad():
        feat = vgg16(arr)
        out, _ = net(arr, feat)
        out = torch.clamp(out, min=0.0, max=1.0)

    # 1xCxHxW -> HxWxC NumPy array for display.
    return out.squeeze(0).permute(1, 2, 0).numpy()
# Text passed to the Gradio interface: the page title, a short description,
# and an HTML snippet given as `article` (shown below the interface).
# TODO(review): the first link's href is empty; fill in the paper URL.
title = "Contrastive Instagram Filter Removal (CIFR)"
description = (
    "This is the demo for CIFR, contrastive strategy for filter removal on fashionable images on Instagram. "
    "To use it, simply upload your filtered image, or click one of the examples to load them."
)
article = (
    "<p style='text-align: center'>"
    "<a href=''>Contrastive Instagram Filter Removal (CIFR)</a>"
    " | "
    "<a href='https://github.com/birdortyedi/cifr-pytorch'>Github Repo</a>"
    "</p>"
)
# Bundled example inputs, listed in the order they appear in the examples table.
example_paths = [
    "images/examples/98_He-Fe.jpg",
    "images/examples/2_Brannan.jpg",
    "images/examples/12_Toaster.jpg",
    "images/examples/18_Gingham.jpg",
    "images/examples/11_Sutro.jpg",
    "images/examples/9_Lo-Fi.jpg",
    "images/examples/3_Mayfair.jpg",
    "images/examples/4_Hudson.jpg",
    "images/examples/5_Amaro.jpg",
    "images/examples/6_1977.jpg",
    "images/examples/8_Valencia.jpg",
    "images/examples/16_Lo-Fi.jpg",
    "images/examples/10_Nashville.jpg",
    "images/examples/15_X-ProII.jpg",
    "images/examples/14_Willow.jpg",
    "images/examples/30_Perpetua.jpg",
    "images/examples/1_Clarendon.jpg",
]

# One 256x256 image in, one image out. Flagging is disabled, and all examples
# are shown on a single page. Each example is a one-element list because the
# interface has a single input.
demo = gr.Interface(
    filter_removal,
    gr.inputs.Image(shape=(256, 256)),
    gr.outputs.Image(),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    examples_per_page=len(example_paths),
    examples=[[path] for path in example_paths],
)
demo.launch()