
Commit e8b8118

Authored Jul 28, 2023
Add files via upload
1 parent 6ab3e57 commit e8b8118

File tree: 2 files changed, +760 −0 lines changed

 

FreeDrag_gradio.py

+309 lines
import gradio as gr
import torch
import numpy as np
from functions import to_image, draw_handle_target_points, free_drag, image_inversion, add_watermark_np
import dnnlib
from training import networks
import legacy
import cv2

# export CUDA_LAUNCH_BLOCKING=1
def load_model(model_name, device):
    # Load a StyleGAN checkpoint and rebuild the generator as a frozen copy.
    path = './checkpoints/' + str(model_name)
    with dnnlib.util.open_url(path) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)
    G_copy = networks.Generator(z_dim=G.z_dim, c_dim=G.c_dim, w_dim=G.w_dim,
                                img_resolution=G.img_resolution,
                                img_channels=G.img_channels,
                                mapping_kwargs=G.init_kwargs['mapping_kwargs'])

    G_copy.load_state_dict(G.state_dict())
    G_copy.to(device)
    del G
    for param in G_copy.parameters():
        param.requires_grad = False
    return G_copy, model_name
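
# A minimal usage sketch (illustrative only; it assumes a checkpoint such as
# 'faces.pkl' has already been downloaded into ./checkpoints/):
#
#   G, name = load_model('faces.pkl', device)
#   z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)
#   label = torch.zeros([1, G.c_dim], device=device)
#   ws = G.get_ws(z, label, truncation_psi=0.7)
#   _, img = G.synthesis(ws=ws, noise_mode='const')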

def draw_mask(image, mask):
    # Blend the binary mask into the image as a translucent white overlay.
    image_mask = image*(1-mask) + mask*(0.7*image + 0.3*255.0)

    return image_mask

class ModelWrapper:
    # Bundle a generator with its name and the per-checkpoint defaults:
    # resolution, expected initial loss (l) and max feature distance (d).
    def __init__(self, model, model_name):
        self.g = model
        self.name = model_name
        self.res = CKPT_SIZE[model_name][0]
        self.l = CKPT_SIZE[model_name][1]
        self.d = CKPT_SIZE[model_name][2]

# Runs the FreeDrag optimization loop; yields intermediate visualizations as
# (img_show, current_target, step_number, full_res, latent_optimized).
def on_drag(model, points, mask, max_iters, latent, sample_interval, l_expected, d_max, save_video, add_watermark):

    if len(points['handle']) == 0:
        raise gr.Error('You must select at least one handle point and target point.')
    if len(points['handle']) != len(points['target']):
        raise gr.Error('You have incomplete handle points; select a target point or undo the handle point.')
    max_iters = int(max_iters)

    handle_size = 128
    train_layer_index = 6
    l_expected = torch.tensor(l_expected, device=latent.device)
    d_max = torch.tensor(d_max, device=latent.device)
    mask[mask > 0] = 1
    global stop_flag
    stop_flag = False
    images_total = []
    for img_show, current_target, step_number, full_res, latent_optimized in free_drag(
            model.g, points, mask[:, :, 0], handle_size, train_layer_index,
            latent, max_iters, l_expected, d_max, sample_interval, device=latent.device):
        image = to_image(img_show)

        # Move the displayed handle points to the currently localized positions.
        points['handle'] = [current_target[p, :].cpu().numpy().astype('int') for p in range(len(current_target[:, 0]))]
        image_show = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[full_res], color="yellow")

        if np.any(mask[:, :, 0] > 0):
            image_show = draw_mask(image_show, mask)
        image_show = np.uint8(image_show)

        if add_watermark:
            image_show = add_watermark_np(np.array(image_show))[:, :, [0, 1, 2]]
            image_clear = add_watermark_np(np.array(image))[:, :, [0, 1, 2]]
        else:
            image_clear = image

        if save_video:
            images_total.append(image_show)
        yield (image_show, step_number, latent_optimized, image_clear, images_total, gr.Button.update(interactive=True))

        if stop_flag:
            break
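
# Sketch of how this generator is consumed (illustrative; Gradio streams each
# yielded tuple into the UI outputs wired up at button_drag.click below):
#
#   for img, step, ws_opt, img_clear, frames, _ in on_drag(
#           wrapper, points, mask, 2000, ws, 5, 0.3, 3, False, 0):
#       print('optimization step', step)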

def add_points_to_image(image, points, size=5, color="red"):
    image = draw_handle_target_points(image, points['handle'], points['target'], size, color)
    return image

def on_show_save():
    return gr.update(visible=True)

def on_click(image, target_point, points, res, evt: gr.SelectData):
    # Clicks alternate between recording a handle point and a target point.
    if target_point:
        points['target'].append([evt.index[1], evt.index[0]])
        image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res])
        return image, not target_point
    points['handle'].append([evt.index[1], evt.index[0]])
    image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res])
    return image, not target_point

def new_image(model, seed=-1):
    if seed == -1:
        seed = np.random.randint(1, int(1e6))
    z1 = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.g.z_dim)).to(device)
    label = torch.zeros([1, model.g.c_dim], device=device)
    ws_original = model.g.get_ws(z1, label, truncation_psi=0.7)
    _, img_show_original = model.g.synthesis(ws=ws_original, noise_mode='const')

    return to_image(img_show_original), to_image(img_show_original), ws_original, seed
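
# Note: sampling is deterministic in the seed (np.random.RandomState), so the
# seed returned here can be fed back in to reproduce the same image, e.g.
# (illustrative; `wrapper` stands for a ModelWrapper instance):
#
#   img, img_clear, ws, seed = new_image(wrapper)    # random seed
#   img2, _, _, _ = new_image(wrapper, seed)         # identical sample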

def new_model(model_name):
    model_load, _ = load_model(model_name, device)
    model = ModelWrapper(model_load, model_name)

    return model, model.res, model.l, model.d

def reset_all(image, mask, add_watermark=0):
    points = {'target': [], 'handle': []}
    target_point = False
    mask = np.zeros_like(mask, dtype=np.uint8)

    return points, target_point, image, None, mask, add_watermark

def add_mask(image_show, mask):
    image_show = draw_mask(image_show, mask)
    return image_show

def update_mask(image, mask_show):
    mask = np.zeros_like(image)
    if mask_show is not None and np.any(mask_show['mask'][:, :, 0] > 1):
        mask[mask_show['mask'][:, :, :3] > 0] = 1
        image_mask = add_mask(image, mask)
        return np.uint8(image_mask), mask
    else:
        return image, mask

def on_select_mask_tab(image):
    return image

def change_stop_state():
    # Signal the on_drag generator loop to stop after its next yield.
    global stop_flag
    stop_flag = True

def save_video(imgs_show_list, frame):
    if len(imgs_show_list) > 0:
        video_name = './process.mp4'
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        full_res = imgs_show_list[0].shape[0]
        video_output = cv2.VideoWriter(video_name, fourcc=fourcc, fps=frame, frameSize=(full_res, full_res))
        for k in range(len(imgs_show_list)):
            # Frames are stored as RGB; OpenCV expects BGR, hence the channel reversal.
            video_output.write(imgs_show_list[k][:, :, ::-1])
        video_output.release()
    return []
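
# Illustrative call (assumes `frames` is a non-empty list of square HxWx3
# uint8 RGB arrays, as collected by on_drag when saving is enabled):
#
#   frames = save_video(frames, frame=5)   # writes ./process.mp4, returns []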

# Per-checkpoint defaults: [resolution, expected initial loss l, max distance d].
CKPT_SIZE = {
    'faces.pkl': [512, 0.3, 3],
    'horses.pkl': [256, 0.3, 3],
    'elephants.pkl': [512, 0.4, 4],
    'lions.pkl': [512, 0.4, 4],
    'dogs.pkl': [1024, 0.4, 4],
    'bicycles.pkl': [256, 0.3, 3],
    'giraffes.pkl': [512, 0.4, 4],
    'cats.pkl': [512, 0.3, 3],
    'cars.pkl': [512, 0.3, 3],
    'churches.pkl': [256, 0.3, 3],
    'metfaces.pkl': [1024, 0.3, 3],
}
# Click-marker size, keyed by image resolution.
RES_TO_CLICK_SIZE = {
    1024: 10,
    512: 5,
    256: 3,
}

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

demo = gr.Blocks()

with demo:

    points = gr.State({'target': [], 'handle': []})
    target_point = gr.State(False)
    state = gr.State({})

    gr.Markdown(
        """
        # **FreeDrag**

        Official implementation of [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://github.com/LPengYang/FreeDrag)

        ## Parameter Description
        **max_step**: the maximum number of optimization steps.

        **sample_interval**: the interval between sampled optimization steps.
        This parameter only affects the visualization of intermediate results and has no impact on the final outcome.
        For high-resolution images (such as the dog model), a larger sample_interval can significantly accelerate the dragging process.

        **Expected initial loss and Max distance**: In the current version, both values are set empirically for each model.
        Generally, for precise editing needs (e.g., merging eyes), smaller values are recommended, which may cause longer processing times.
        Users can set these values according to practical editing requirements. We are currently seeking an automated solution.

        **frame_rate**: the frame rate of the saved video.

        ## Hints
        - Handle points (blue): the points you want to drag.
        - Target points (red): the destinations you want to drag towards.
        - **Localized points (yellow)**: the localized points of each sub-motion.
        """,
    )
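
    # The parameters documented above flow directly into the optimization:
    # max_step, sample_interval, l_expected and d_max are wired into on_drag
    # (see button_drag.click below), which tensorizes l_expected/d_max on the
    # latent's device before calling free_drag.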

    with gr.Row():
        with gr.Column(scale=0.4):
            with gr.Accordion("Model"):
                with gr.Row():
                    with gr.Column(min_width=100):
                        seed = gr.Number(label='Seed', value=0)
                    with gr.Column(min_width=100):
                        button_new = gr.Button('Image Generate', variant='primary')
                        button_rand = gr.Button('Rand Generate')
                model_name = gr.Dropdown(label="Model name", choices=list(CKPT_SIZE.keys()), value=list(CKPT_SIZE.keys())[0])

            with gr.Accordion('Optional Parameters'):
                with gr.Row():
                    with gr.Column(min_width=100):
                        max_step = gr.Number(label='Max step', value=2000)
                    with gr.Column(min_width=100):
                        sample_interval = gr.Number(label='Interval', value=5, info="Sampling interval")

                model_load, _ = load_model(model_name.value, device)
                model = gr.State(ModelWrapper(model_load, model_name.value))
                l_expected = gr.Slider(0.1, 0.5, label='Expected initial loss for each sub-motion', value=model.value.l, step=0.05)
                d_max = gr.Slider(1.0, 6.0, label='Max distance for each sub-motion (in the feature map)', value=model.value.d, step=0.5)

            res = gr.State(model.value.res)
            z1 = torch.from_numpy(np.random.RandomState(int(seed.value)).randn(1, model.value.g.z_dim)).to(device)
            label = torch.zeros([1, model.value.g.c_dim], device=device)
            ws_original = model.value.g.get_ws(z1, label, truncation_psi=0.7)
            latent = gr.State(ws_original)
            add_watermark = gr.State(torch.zeros(1, device=device))

            _, img_show_original = model.value.g.synthesis(ws=ws_original, noise_mode='const')

            with gr.Accordion('Video'):
                images_total = gr.State([])
                with gr.Row():
                    with gr.Column(min_width=100):
                        if_save_video = gr.Radio(["True", "False"], value="False", label="if save video")
                    with gr.Column(min_width=100):
                        frame_rate = gr.Number(label="Frame rate", value=5)
                with gr.Row():
                    with gr.Column(min_width=100):
                        button_video = gr.Button('Save video', variant='primary')

            with gr.Accordion('Drag'):

                with gr.Row():
                    with gr.Column(min_width=200):
                        reset_btn = gr.Button('Reset points and mask')
                with gr.Row():
                    with gr.Column(min_width=100):
                        button_drag = gr.Button('Drag it', variant='primary')
                    with gr.Column(min_width=100):
                        button_stop = gr.Button('Stop')

                progress = gr.Number(value=0, label='Steps', interactive=False)

        with gr.Column(scale=0.53):
            with gr.Tabs() as Tabs:
                image_show = to_image(img_show_original)
                image_clear = gr.State(image_show)
                mask = gr.State(np.zeros_like(image_clear.value))
                with gr.Tab('Setup Handle Points', id='input') as imagetab:
                    image = gr.Image(image_show).style(height=768, width=768)
                with gr.Tab('Draw a Mask', id='mask') as masktab:
                    mask_show = gr.ImageMask(image_show).style(height=768, width=768)

    image.select(on_click, [image, target_point, points, res], [image, target_point]).then(on_show_save)

    image.upload(image_inversion, [image, res, model], [latent, image, image_clear, add_watermark]).then(
        reset_all, inputs=[image_clear, mask, add_watermark], outputs=[points, target_point, image, mask_show, mask, add_watermark])

    button_drag.click(on_drag, inputs=[model, points, mask, max_step, latent, sample_interval, l_expected, d_max, if_save_video, add_watermark],
                      outputs=[image, progress, latent, image_clear, images_total, button_stop])
    button_stop.click(change_stop_state)

    button_video.click(save_video, inputs=[images_total, frame_rate], outputs=[images_total])
    reset_btn.click(reset_all, inputs=[image_clear, mask, add_watermark], outputs=[points, target_point, image, mask_show, mask, add_watermark]).then(on_show_save)

    button_new.click(new_image, inputs=[model, seed], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask, add_watermark])

    button_rand.click(new_image, inputs=[model], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask, add_watermark])

    model_name.change(new_model, inputs=[model_name], outputs=[model, res, l_expected, d_max]).then(
        new_image, inputs=[model, seed], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask, add_watermark])

    imagetab.select(update_mask, [image, mask_show], [image, mask])
    masktab.select(on_select_mask_tab, inputs=[image], outputs=[mask_show])

if __name__ == "__main__":

    demo.queue(concurrency_count=3, max_size=20).launch(share=True)
