Commit 9f73ebf

Add files via upload
1 parent 5470965 commit 9f73ebf


63 files changed: +5794 −0 lines

FreeDrag_gradio.py

+290
@@ -0,0 +1,290 @@
import gradio as gr
import torch
import numpy as np
from functions import draw_handle_target_points, free_drag
import dnnlib
from training import networks
import legacy
import cv2

# export CUDA_LAUNCH_BLOCKING=1

def load_model(model_name, device):

    path = './checkpoints/' + str(model_name)
    with dnnlib.util.open_url(path) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

    # Rebuild the generator from the local networks.py (instead of the pickled
    # class), copy over the pretrained weights, then freeze them.
    G_copy = networks.Generator(z_dim=G.z_dim, c_dim=G.c_dim, w_dim=G.w_dim,
                                img_resolution=G.img_resolution,
                                img_channels=G.img_channels,
                                mapping_kwargs=G.init_kwargs['mapping_kwargs'])

    G_copy.load_state_dict(G.state_dict())
    G_copy.to(device)
    del G
    for param in G_copy.parameters():
        param.requires_grad = False
    return G_copy, model_name


def to_image(tensor):
    tensor = tensor.squeeze(0).permute(1, 2, 0)
    arr = tensor.detach().cpu().numpy()
    # Normalize to [0, 255] for display.
    arr = (arr - arr.min()) / (arr.max() - arr.min())
    arr = arr * 255
    return arr.astype('uint8')


def draw_mask(image, mask):
    # Lighten the masked region by blending 30% white into it.
    image_mask = image * (1 - mask) + mask * (0.7 * image + 0.3 * 255.0)
    return image_mask


class ModelWrapper:
    def __init__(self, model, model_name):
        self.g = model
        self.name = model_name
        self.size = CKPT_SIZE[model_name][0]   # image resolution
        self.l = CKPT_SIZE[model_name][1]      # expected initial loss per sub-motion
        self.d = CKPT_SIZE[model_name][2]      # max feature-space distance per sub-motion


# free_drag: model, points, mask, feature_size, train_layer_index, max_step, device, seed=2023, max_distance=3, d=0.5
# yields: img_show, current_target, step_number
def on_drag(model, points, mask, max_iters, latent, sample_interval, l_expected, d_max, save_video):

    if len(points['handle']) == 0:
        raise gr.Error('You must select at least one handle point and target point.')
    if len(points['handle']) != len(points['target']):
        raise gr.Error('You have incomplete handle points; select a target point or undo the handle point.')
    max_iters = int(max_iters)

    handle_size = 128
    train_layer_index = 6
    l_expected = torch.tensor(l_expected, device=latent.device)
    d_max = torch.tensor(d_max, device=latent.device)
    mask[mask > 0] = 1

    images_total = []
    # free_drag is a generator: stream intermediate results to the UI as the
    # handle points move through their sub-motions.
    for img_show, current_target, step_number, full_size, latent_optimized in free_drag(
            model.g, points, mask[:, :, 0], handle_size, train_layer_index,
            latent, max_iters, l_expected, d_max, sample_interval, device=latent.device):
        image = to_image(img_show)

        points['handle'] = [current_target[p, :].cpu().numpy().astype('int') for p in range(len(current_target[:, 0]))]
        image_show = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[full_size], color="yellow")

        if np.any(mask[:, :, 0] > 0):
            image_show = draw_mask(image_show, mask)
        image_show = np.uint8(image_show)

        # if_save_video is a Radio of strings, so compare explicitly.
        if save_video == "True":
            images_total.append(image_show)
        yield image_show, step_number, latent_optimized, image, images_total


def add_points_to_image(image, points, size=5, color="red"):
    image = draw_handle_target_points(image, points['handle'], points['target'], size, color)
    return image


def on_show_save():
    return gr.update(visible=True)


def on_click(image, target_point, points, size, evt: gr.SelectData):
    # Clicks alternate between handle and target points;
    # evt.index is (x, y), points are stored as (row, col).
    if target_point:
        points['target'].append([evt.index[1], evt.index[0]])
        image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
        return image, not target_point
    points['handle'].append([evt.index[1], evt.index[0]])
    image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
    return image, not target_point


def new_image(model, seed=-1):
    if seed == -1:
        seed = np.random.randint(1, int(1e6))
    z1 = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.g.z_dim)).to(device)
    label = torch.zeros([1, model.g.c_dim], device=device)
    ws_original = model.g.get_ws(z1, label, truncation_psi=0.7)
    _, img_show_original = model.g.synthesis(ws=ws_original, noise_mode='const')

    return to_image(img_show_original), to_image(img_show_original), ws_original, seed


def new_model(model_name):
    model_load, _ = load_model(model_name, device)
    model = ModelWrapper(model_load, model_name)
    return model, model.size, model.l, model.d


def reset_all(image, mask):
    points = {'target': [], 'handle': []}
    target_point = False
    mask = np.zeros_like(mask, dtype=np.uint8)
    return points, target_point, image, None, mask


def add_mask(image_show, mask):
    image_show = draw_mask(image_show, mask)
    return image_show


def update_mask(image, mask_show):
    mask = np.zeros_like(image)
    if mask_show is not None and np.any(mask_show['mask'][:, :, 0] > 1):
        mask[mask_show['mask'][:, :, :3] > 0] = 1
        image_mask = add_mask(image, mask)
        return np.uint8(image_mask), mask
    else:
        return image, mask


def on_select_mask_tab(image):
    return image


def save_video(imgs_show_list, frame):
    if len(imgs_show_list) > 0:
        video_name = './process.mp4'
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        full_size = imgs_show_list[0].shape[0]
        video_output = cv2.VideoWriter(video_name, fourcc=fourcc, fps=frame, frameSize=(full_size, full_size))
        for k in range(len(imgs_show_list)):
            # OpenCV expects BGR, frames are RGB: reverse the channel order.
            video_output.write(imgs_show_list[k][:, :, ::-1])
        video_output.release()
    return []


# Per-checkpoint presets: [image resolution, expected initial loss, max feature-space distance]
CKPT_SIZE = {
    'faces.pkl': [512, 0.3, 3],
    'horses.pkl': [256, 0.3, 3],
    'elephants.pkl': [512, 0.4, 4],
    'lions.pkl': [512, 0.4, 4],
    'dogs.pkl': [1024, 0.4, 4],
    'bicycles.pkl': [256, 0.3, 3],
    'giraffes.pkl': [512, 0.4, 4],
    'cats.pkl': [512, 0.3, 3],
    'cars.pkl': [512, 0.3, 3],
    'churches.pkl': [256, 0.3, 3],
    'metfaces.pkl': [1024, 0.3, 3],
}
# Marker radius for drawn points, by image resolution.
SIZE_TO_CLICK_SIZE = {
    1024: 10,
    512: 5,
    256: 3,
}

device = 'cuda'
demo = gr.Blocks()

with demo:

    points = gr.State({'target': [], 'handle': []})
    target_point = gr.State(False)
    state = gr.State({})

    gr.Markdown(
        """
# **FreeDrag**

Official implementation of [FreeDrag: Point Tracking is Not What You Need for Interactive Point-based Image Editing](https://github.com/LPengYang/FreeDrag)

## Parameter Description
**max_step**: max number of optimization steps

**sample_interval**: the interval between sampled optimization steps.
This parameter only affects the visualization of intermediate results and has no impact on the final outcome.
For high-resolution images (such as the dog model), a larger sample_interval can significantly accelerate the dragging process.

**Expected initial loss and Max distance**: In the current version, both of these values are empirically set for each model.
Generally, for precise editing needs (e.g., merging eyes), smaller values are recommended, which may cause longer processing times.
Users can set these values according to practical editing requirements. We are currently seeking an automated solution.

**frame_rate**: the frame rate for the saved video.

## Hints
- Handle points (Blue): the points you want to drag.
- Target points (Red): the destinations you want to drag towards.
- **Localized points (Yellow)**: the localized points in each sub-motion.
""",
    )

    with gr.Row():
        with gr.Column(scale=0.4):
            with gr.Accordion("Model"):
                with gr.Row():
                    with gr.Column(min_width=100):
                        seed = gr.Number(label='Seed', value=0)
                    with gr.Column(min_width=100):
                        button_new = gr.Button('Image Generate', variant='primary')
                        button_rand = gr.Button('Rand Generate')
                model_name = gr.Dropdown(label="Model name", choices=list(CKPT_SIZE.keys()), value=list(CKPT_SIZE.keys())[0])

            with gr.Accordion('Optional Parameters'):
                with gr.Row():
                    with gr.Column(min_width=100):
                        max_step = gr.Number(label='Max step', value=2000)
                    with gr.Column(min_width=100):
                        sample_interval = gr.Number(label='Interval', value=5, info="Sampling interval")

                model_load, _ = load_model(model_name.value, device)
                model = gr.State(ModelWrapper(model_load, model_name.value))
                l_expected = gr.Slider(0.1, 0.5, label='Expected initial loss for each sub-motion', value=model.value.l, step=0.05)
                d_max = gr.Slider(1.0, 6.0, label='Max distance for each sub-motion (in the feature map)', value=model.value.d, step=0.5)

            size = gr.State(model.value.size)
            z1 = torch.from_numpy(np.random.RandomState(int(seed.value)).randn(1, model.value.g.z_dim)).to(device)
            label = torch.zeros([1, model.value.g.c_dim], device=device)
            ws_original = model.value.g.get_ws(z1, label, truncation_psi=0.7)
            latent = gr.State(ws_original)

            _, img_show_original = model.value.g.synthesis(ws=ws_original, noise_mode='const')

            with gr.Accordion('Video'):
                images_total = gr.State([])
                with gr.Row():
                    with gr.Column(min_width=100):
                        if_save_video = gr.Radio(["True", "False"], value="False", label="if save video")
                    with gr.Column(min_width=100):
                        frame_rate = gr.Number(label="Frame rate", value=5)
                with gr.Row():
                    with gr.Column(min_width=100):
                        button_video = gr.Button('Save video', variant='primary')

            with gr.Accordion('Drag'):

                with gr.Row():
                    with gr.Column(min_width=200):
                        reset_btn = gr.Button('Reset points and mask')
                with gr.Row():
                    button_drag = gr.Button('Drag it', variant='primary')

                progress = gr.Number(value=0, label='Steps', interactive=False)

        with gr.Column(scale=0.53):
            with gr.Tabs() as Tabs:
                image_show = to_image(img_show_original)
                image_clear = gr.State(image_show)
                mask = gr.State(np.zeros_like(image_clear.value))
                with gr.Tab('Setup Handle Points', id='input') as imagetab:
                    image = gr.Image(image_show).style(height=768, width=768)
                with gr.Tab('Draw a Mask', id='mask') as masktab:
                    mask_show = gr.ImageMask(image_show).style(height=768, width=768)

    image.select(on_click, [image, target_point, points, size], [image, target_point]).then(on_show_save)

    button_drag.click(on_drag, inputs=[model, points, mask, max_step, latent, sample_interval, l_expected, d_max, if_save_video],
                      outputs=[image, progress, latent, image_clear, images_total])

    button_video.click(save_video, inputs=[images_total, frame_rate], outputs=[images_total])
    reset_btn.click(reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask]).then(on_show_save)

    button_new.click(new_image, inputs=[model, seed], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask])

    button_rand.click(new_image, inputs=[model], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask])

    model_name.change(new_model, inputs=[model_name], outputs=[model, size, l_expected, d_max]).then(
        new_image, inputs=[model, seed], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask], outputs=[points, target_point, image, mask_show, mask])

    imagetab.select(update_mask, [image, mask_show], [image, mask])
    masktab.select(on_select_mask_tab, inputs=[image], outputs=[mask_show])


if __name__ == "__main__":
    demo.queue(concurrency_count=1, max_size=30).launch()
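
The app above drives free_drag entirely through Gradio callbacks. For reference, here is a minimal headless sketch of the same pipeline, reusing load_model, to_image, and the free_drag call made in on_drag. The checkpoint name, seed, point coordinates, and hyperparameter values are illustrative assumptions, not taken from the commit:

# Hypothetical headless driver (not part of the commit): it mirrors the
# free_drag(...) call inside on_drag() above.
import numpy as np
import torch

device = 'cuda'
G, _ = load_model('faces.pkl', device)           # assumes ./checkpoints/faces.pkl exists

seed = 2023                                      # illustrative seed
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
label = torch.zeros([1, G.c_dim], device=device)
ws = G.get_ws(z, label, truncation_psi=0.7)

# One handle/target pair in (row, col) order, as stored by on_click().
points = {'handle': [[256, 200]], 'target': [[256, 260]]}
mask_2d = np.zeros((512, 512), dtype=np.uint8)   # empty mask: the whole image may change

final_image = None
for img_show, current_target, step, full_size, ws_opt in free_drag(
        G, points, mask_2d, 128,                 # handle_size=128, as in on_drag
        6, ws, 2000,                             # train_layer_index=6, max_iters=2000
        torch.tensor(0.3, device=device),        # l_expected (faces.pkl preset)
        torch.tensor(3, device=device),          # d_max (faces.pkl preset)
        5, device=device):                       # sample_interval=5
    final_image = to_image(img_show)             # keep the latest frame

Since free_drag is a generator, this loop receives the same intermediate frames the UI streams; keeping only the last one yields the final edit.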

__init__.py

Whitespace-only changes.

__pycache__/functions.cpython-39.pyc

8.33 KB
Binary file not shown.

__pycache__/legacy.cpython-38.pyc

14.6 KB
Binary file not shown.

__pycache__/legacy.cpython-39.pyc

14.5 KB
Binary file not shown.

__pycache__/test.cpython-38.pyc

1.95 KB
Binary file not shown.

__pycache__/test.cpython-39.pyc

4.42 KB
Binary file not shown.

dnnlib/__init__.py

+9
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
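
dnnlib's EasyDict, re-exported here, is a plain dict subclass that also exposes its keys as attributes; StyleGAN-family code uses it for nested configs such as the mapping_kwargs passed to networks.Generator above. A small illustration (not from the commit):

from dnnlib import EasyDict

cfg = EasyDict(z_dim=512, w_dim=512)
cfg.img_resolution = 512                 # attribute writes create dict keys
assert cfg['img_resolution'] == cfg.img_resolution == 512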
4 binary files not shown (226 Bytes, 226 Bytes, 13.4 KB, 13.4 KB).
