Skip to content

Commit feefba8

Browse files
committed
add demo
1 parent 5277042 commit feefba8

File tree

10 files changed

+385
-3
lines changed

10 files changed

+385
-3
lines changed

README.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,10 @@ you may modify the training script to use different settings, e.g., batch size,
6161

6262
### Web APP
6363
<img src="https://s3.ax1x.com/2020/11/27/DrVLs1.png" width=300>
64-
For your convinience of visualization and evaluation, I provide an inpainting APP where you can interact with the inpainting model in a browser, to open a photo and draw area to remove. To use the web app, these additional packages are required:
6564

66-
```flask```, ```requests```, ```pillow```
65+
To use the web app, these additional packages are required:
6766

68-
Then execute the following:
67+
```flask```, ```requests```, ```pillow```
6968

7069
With GPU:
7170
```

demo.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import pdb
2+
import cv2
3+
import os
4+
from collections import OrderedDict
5+
6+
import numpy as np
7+
from werkzeug.utils import secure_filename
8+
from flask import Flask, url_for, render_template, request, redirect, send_from_directory
9+
from PIL import Image
10+
import base64
11+
import io
12+
import random
13+
14+
15+
from options.test_options import TestOptions
16+
import models
17+
import torch
18+
19+
opt = TestOptions().parse()
20+
model = models.create_model(opt)
21+
model.eval()
22+
23+
max_size = 256
24+
max_num_examples = 200
25+
UPLOAD_FOLDER = 'static/images'
26+
app = Flask(__name__)
27+
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
28+
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'jpeg', 'bmp'])
29+
def allowed_file(filename):
30+
return '.' in filename and \
31+
filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
32+
33+
port = opt.port
34+
filelist = "./static/images/example.txt"
35+
with open(filelist, "r") as f:
36+
list_examples = f.readlines()
37+
list_examples = [n.strip("\n") for n in list_examples]
38+
39+
def process_image(img, mask, name, opt, save_to_input=True):
40+
img =img.convert("RGB")
41+
img_raw = np.array(img)
42+
w_raw, h_raw = img.size
43+
h_t, w_t = h_raw//8*8, w_raw//8*8
44+
45+
img = img.resize((w_t, h_t))
46+
img = np.array(img).transpose((2,0,1))
47+
48+
mask_raw = np.array(mask)[...,None]>0
49+
mask = mask.resize((w_t, h_t))
50+
51+
mask = np.array(mask)
52+
mask = (torch.Tensor(mask)>0).float()
53+
img = (torch.Tensor(img)).float()
54+
img = (img/255-0.5)/0.5
55+
img = img[None]
56+
mask = mask[None,None]
57+
58+
with torch.no_grad():
59+
generated,_ = model(
60+
{'image':img,'mask':mask},
61+
mode='inference')
62+
generated = torch.clamp(generated, -1, 1)
63+
generated = (generated+1)/2*255
64+
generated = generated.cpu().numpy().astype(np.uint8)
65+
generated = generated[0].transpose((1,2,0))
66+
result = generated*mask_raw+img_raw*(1-mask_raw)
67+
result = result.astype(np.uint8)
68+
69+
result = Image.fromarray(result).resize((w_raw, h_raw))
70+
result = np.array(result)
71+
result = Image.fromarray(result.astype(np.uint8))
72+
result.save(f"static/results/{name}")
73+
if save_to_input:
74+
result.save(f"static/images/{name}")
75+
76+
@app.route('/', methods=['GET', 'POST'])
77+
def hello(name=None):
78+
if 'example' in request.form:
79+
filename= request.form['example']
80+
image = Image.open(os.path.join(os.path.join(app.config['UPLOAD_FOLDER'], filename)))
81+
W, H = image.size
82+
return render_template('hello.html', name=name, image_name=filename, image_width=W,
83+
image_height=H,list_examples=list_examples)
84+
if request.method == 'POST':
85+
if 'file' in request.files:
86+
file = request.files['file']
87+
if file and allowed_file(file.filename):
88+
filename = secure_filename(file.filename)
89+
image = Image.open(file)
90+
W, H = image.size
91+
if max(W, H) > max_size:
92+
ratio = float(max_size) / max(W, H)
93+
W = int(W*ratio)
94+
H = int(H*ratio)
95+
image = image.resize((W, H))
96+
filename = "resize_"+filename
97+
image.save(os.path.join(os.path.join(app.config['UPLOAD_FOLDER'], filename)))
98+
return render_template('hello.html', name=name, image_name=filename, image_width=W,
99+
image_height=H,list_examples=list_examples)
100+
else:
101+
filename=list_examples[0]
102+
image = Image.open(os.path.join(os.path.join(app.config['UPLOAD_FOLDER'], filename)))
103+
W, H = image.size
104+
return render_template('hello.html', name=name, image_name=filename, image_width=W, image_height=H,
105+
is_alert=True,list_examples=list_examples)
106+
if 'mask' in request.form:
107+
filename = request.form['imgname']
108+
mask_data = request.form['mask']
109+
mask_data = mask_data.replace('data:image/png;base64,', '')
110+
mask_data = mask_data.replace(' ', '+')
111+
mask = base64.b64decode(mask_data)
112+
maskname = '.'.join(filename.split('.')[:-1]) + '.png'
113+
maskname = maskname.replace("/","_")
114+
maskname = "{}_{}".format(random.randint(0, 1000), maskname)
115+
with open(os.path.join('static/masks', maskname), "wb") as fh:
116+
fh.write(mask)
117+
mask = io.BytesIO(mask)
118+
mask = Image.open(mask).convert("L")
119+
image = Image.open(os.path.join(os.path.join(app.config['UPLOAD_FOLDER'], filename)))
120+
W, H = image.size
121+
list_op = ["result"]
122+
for op in list_op:
123+
process_image(image, mask, f"{op}_"+maskname, op, save_to_input=True)
124+
return render_template('hello.html', name=name, image_name=filename, #f"{args.opt[0]}_"+maskname
125+
mask_name=maskname, image_width=W, image_height=H, list_opt=list_op,list_examples=list_examples)
126+
else:
127+
filename=list_examples[0]
128+
image = Image.open(os.path.join(os.path.join(app.config['UPLOAD_FOLDER'], filename)))
129+
W, H = image.size
130+
return render_template('hello.html', name=name, image_name=filename, image_width=W, image_height=H,
131+
list_examples=list_examples)
132+
133+
134+
135+
if __name__ == "__main__":
136+
137+
app.run(host='0.0.0.0', debug=True, port=port, threaded=True)

demo.sh

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
python demo.py \
2+
--name objrmv \
3+
--dataset_mode testimage \
4+
--model inpaint \
5+
--netG baseconv \
6+
--which_epoch latest \
7+
--image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \
8+
--mask_dir ./datasets/places2sample1k_val/places2samples1k_256_mask_square128 \
9+
--output_dir ./results \
10+
--port 8897 \

options/test_options.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
class TestOptions(BaseOptions):
1010
def initialize(self, parser):
1111
BaseOptions.initialize(self, parser)
12+
parser.add_argument('--port', type=int, default=8897)
1213
parser.add_argument('--dataset_mode', type=str, default='coco')
1314
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
1415
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')

static/images/example.txt

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
images/Places365_val_00020514.png
2+
images/Places365_val_00016309.png
3+
images/Places365_val_00007134.png
4+
images/Places365_val_00015508.png
5+
images/Places365_val_00017950.png
6+
images/Places365_val_00007374.png
7+
images/Places365_val_00012680.png
8+
images/Places365_val_00014413.png
9+
images/Places365_val_00032255.png
10+
images/Places365_val_00015682.png
11+
images/Places365_val_00014142.png
12+
images/Places365_val_00026857.png
13+
images/Places365_val_00004640.png
14+
images/Places365_val_00031625.png
15+
images/Places365_val_00001219.png
16+
images/Places365_val_00026150.png
17+
images/Places365_val_00005438.png
18+
images/Places365_val_00030416.png
19+
images/Places365_val_00028080.png

static/images/images

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../examples/places/images

static/jquery.min.js

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

static/masks/.gitkeep

Whitespace-only changes.

static/results/.gitkeep

Whitespace-only changes.

0 commit comments

Comments
 (0)