import gradio as gr
import torch
import numpy as np
from functions import to_image, draw_handle_target_points, free_drag, image_inversion, add_watermark_np
import dnnlib
from training import networks
import legacy
import cv2

# export CUDA_LAUNCH_BLOCKING=1
def load_model(model_name, device):
    # Load the pretrained weights and copy them into a fresh Generator instance.
    # The copy is frozen: FreeDrag optimizes the latent code, not the weights.
    path = './checkpoints/' + str(model_name)
    with dnnlib.util.open_url(path) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)
    G_copy = networks.Generator(z_dim=G.z_dim, c_dim=G.c_dim, w_dim=G.w_dim,
                                img_resolution=G.img_resolution,
                                img_channels=G.img_channels,
                                mapping_kwargs=G.init_kwargs['mapping_kwargs'])

    G_copy.load_state_dict(G.state_dict())
    G_copy.to(device)
    del G
    for param in G_copy.parameters():
        param.requires_grad = False
    return G_copy, model_name
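
# A minimal usage sketch (assuming the matching checkpoint has been downloaded
# to ./checkpoints beforehand):
#   G, name = load_model('cats.pkl', device)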

def draw_mask(image, mask):
    # Highlight the masked region: unmasked pixels pass through, masked pixels
    # are blended 70/30 with white so the editable area reads as a highlight.
    image_mask = image * (1 - mask) + mask * (0.7 * image + 0.3 * 255.0)

    return image_mask
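
# For example, a masked pixel of value 100 blends to 0.7 * 100 + 0.3 * 255 = 146.5,
# so dark pixels shift noticeably toward white while bright pixels barely change.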


class ModelWrapper:
    # Bundles a generator with its per-checkpoint defaults from CKPT_SIZE:
    # native resolution, expected initial loss l, and max sub-motion distance d.
    def __init__(self, model, model_name):
        self.g = model
        self.name = model_name
        self.res = CKPT_SIZE[model_name][0]
        self.l = CKPT_SIZE[model_name][1]
        self.d = CKPT_SIZE[model_name][2]


# model, points, mask, feature_size, train_layer_index, max_step, device, seed=2023, max_distance=3, d=0.5
# img_show, current_target, step_number
def on_drag(model, points, mask, max_iters, latent, sample_interval, l_expected, d_max, save_video, add_watermark):

    if len(points['handle']) == 0:
        raise gr.Error('You must select at least one handle point and target point.')
    if len(points['handle']) != len(points['target']):
        raise gr.Error('You have incomplete handle points; try to select a target point or undo the handle point.')
    max_iters = int(max_iters)

    # The "if save video" radio passes its choice as a string, so normalize it
    # to a real boolean (the string "False" would otherwise be truthy).
    if isinstance(save_video, str):
        save_video = (save_video == "True")

    handle_size = 128
    train_layer_index = 6
    l_expected = torch.tensor(l_expected, device=latent.device)
    d_max = torch.tensor(d_max, device=latent.device)
    mask[mask > 0] = 1
    global stop_flag
    stop_flag = False
    images_total = []
    for img_show, current_target, step_number, full_res, latent_optimized in free_drag(
            model.g, points, mask[:, :, 0], handle_size,
            train_layer_index, latent, max_iters, l_expected, d_max, sample_interval, device=latent.device):
        image = to_image(img_show)

        # The handle points move each step; redraw them (in yellow) at their current positions.
        points['handle'] = [current_target[p, :].cpu().numpy().astype('int') for p in range(len(current_target[:, 0]))]
        image_show = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[full_res], color="yellow")

        if np.any(mask[:, :, 0] > 0):
            image_show = draw_mask(image_show, mask)
        image_show = np.uint8(image_show)

        if add_watermark:
            image_show = add_watermark_np(np.array(image_show))[:, :, [0, 1, 2]]
            image_clear = add_watermark_np(np.array(image))[:, :, [0, 1, 2]]
        else:
            image_clear = image

        if save_video:
            images_total.append(image_show)
        yield (image_show, step_number, latent_optimized, image_clear, images_total, gr.Button.update(interactive=True))

        if stop_flag:
            break

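# Note: on_drag is a generator; with the queue enabled (see demo.queue below),
# each yield streams the current frame, step count, and optimized latent to the
# bound output components.
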
def add_points_to_image(image, points, size=5, color="red"):
    image = draw_handle_target_points(image, points['handle'], points['target'], size, color)
    return image

def on_show_save():
    return gr.update(visible=True)

def on_click(image, target_point, points, res, evt: gr.SelectData):
    # Clicks alternate between handle and target points; evt.index is swapped
    # from (x, y) to (row, col) order before being stored.
    if target_point:
        points['target'].append([evt.index[1], evt.index[0]])
        image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res])
        return image, not target_point
    points['handle'].append([evt.index[1], evt.index[0]])
    image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res])
    return image, not target_point

def new_image(model, seed=-1):
    if seed == -1:
        seed = np.random.randint(1, 10**6)
    z1 = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.g.z_dim)).to(device)
    label = torch.zeros([1, model.g.c_dim], device=device)
    ws_original = model.g.get_ws(z1, label, truncation_psi=0.7)
    _, img_show_original = model.g.synthesis(ws=ws_original, noise_mode='const')

    return to_image(img_show_original), to_image(img_show_original), ws_original, seed

def new_model(model_name):
    model_load, _ = load_model(model_name, device)
    model = ModelWrapper(model_load, model_name)

    return model, model.res, model.l, model.d

def reset_all(image, mask, add_watermark=0):
    points = {'target': [], 'handle': []}
    target_point = False
    mask = np.zeros_like(mask, dtype=np.uint8)

    return points, target_point, image, None, mask, add_watermark

def add_mask(image_show, mask):
    image_show = draw_mask(image_show, mask)
    return image_show

def update_mask(image, mask_show):
    mask = np.zeros_like(image)
    if mask_show is not None and np.any(mask_show['mask'][:, :, 0] > 1):
        mask[mask_show['mask'][:, :, :3] > 0] = 1
        image_mask = add_mask(image, mask)
        return np.uint8(image_mask), mask
    else:
        return image, mask

def on_select_mask_tab(image):
    return image

def change_stop_state():
    # Signal the on_drag loop to break after the current optimization step.
    global stop_flag
    stop_flag = True

def save_video(imgs_show_list, frame):
    if len(imgs_show_list) > 0:
        video_name = './process.mp4'
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        full_res = imgs_show_list[0].shape[0]
        video_output = cv2.VideoWriter(video_name, fourcc=fourcc, fps=frame, frameSize=(full_res, full_res))
        for k in range(len(imgs_show_list)):
            # OpenCV expects BGR, so reverse the RGB channel order before writing.
            video_output.write(imgs_show_list[k][:, :, ::-1])
        video_output.release()
    return []

# Per-checkpoint defaults: [native resolution, expected initial loss l, max distance d].
CKPT_SIZE = {
    'faces.pkl': [512, 0.3, 3],
    'horses.pkl': [256, 0.3, 3],
    'elephants.pkl': [512, 0.4, 4],
    'lions.pkl': [512, 0.4, 4],
    'dogs.pkl': [1024, 0.4, 4],
    'bicycles.pkl': [256, 0.3, 3],
    'giraffes.pkl': [512, 0.4, 4],
    'cats.pkl': [512, 0.3, 3],
    'cars.pkl': [512, 0.3, 3],
    'churches.pkl': [256, 0.3, 3],
    'metfaces.pkl': [1024, 0.3, 3],
}
# Marker size for clicked points, keyed by image resolution.
RES_TO_CLICK_SIZE = {
    1024: 10,
    512: 5,
    256: 3,
}

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

demo = gr.Blocks()

with demo:

    points = gr.State({'target': [], 'handle': []})
    target_point = gr.State(False)
    state = gr.State({})

    gr.Markdown(
        """
        # **FreeDrag**

        Official implementation of [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://github.com/LPengYang/FreeDrag)


        ## Parameter Description
        **max_step**: the maximum number of optimization steps.

        **sample_interval**: the interval between sampled optimization steps.
        This parameter only affects the visualization of intermediate results and has no impact on the final outcome.
        For high-resolution images (such as the dog model), a larger sample_interval can significantly accelerate the dragging process.

        **Expected initial loss and Max distance**: in the current version, both of these values are set empirically for each model.
        Generally, for precise editing needs (e.g., merging eyes), smaller values are recommended, though they may cause longer processing times.
        Users can set these values according to practical editing requirements. We are currently seeking an automated solution.

        **frame_rate**: the frame rate of the saved video.

        ## Hints
        - Handle points (blue): the points you want to drag.
        - Target points (red): the destinations you want to drag towards.
        - **Localized points (yellow)**: the localized points of the current sub-motion.
        """,
    )
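
    # The parameters described above map onto the controls declared below:
    # max_step and sample_interval are gr.Number inputs, the expected loss and
    # max distance pair are gr.Sliders, and frame_rate feeds save_video.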

    with gr.Row():
        with gr.Column(scale=0.4):
            with gr.Accordion("Model"):
                with gr.Row():
                    with gr.Column(min_width=100):
                        seed = gr.Number(label='Seed', value=0)
                    with gr.Column(min_width=100):
                        button_new = gr.Button('Image Generate', variant='primary')
                        button_rand = gr.Button('Rand Generate')
                model_name = gr.Dropdown(label="Model name", choices=list(CKPT_SIZE.keys()), value=list(CKPT_SIZE.keys())[0])

            with gr.Accordion('Optional Parameters'):
                with gr.Row():
                    with gr.Column(min_width=100):
                        max_step = gr.Number(label='Max step', value=2000)
                    with gr.Column(min_width=100):
                        sample_interval = gr.Number(label='Interval', value=5, info="Sampling interval")

                model_load, _ = load_model(model_name.value, device)
                model = gr.State(ModelWrapper(model_load, model_name.value))
                l_expected = gr.Slider(0.1, 0.5, label='Expected initial loss for each sub-motion', value=model.value.l, step=0.05)
                d_max = gr.Slider(1.0, 6.0, label='Max distance for each sub-motion (in the feature map)', value=model.value.d, step=0.5)

            res = gr.State(model.value.res)
            # Synthesize the initial preview image for the default seed and model.
            z1 = torch.from_numpy(np.random.RandomState(int(seed.value)).randn(1, model.value.g.z_dim)).to(device)
            label = torch.zeros([1, model.value.g.c_dim], device=device)
            ws_original = model.value.g.get_ws(z1, label, truncation_psi=0.7)
            latent = gr.State(ws_original)
            add_watermark = gr.State(torch.zeros(1, device=device))

            _, img_show_original = model.value.g.synthesis(ws=ws_original, noise_mode='const')

            with gr.Accordion('Video'):
                images_total = gr.State([])
                with gr.Row():
                    with gr.Column(min_width=100):
                        if_save_video = gr.Radio(["True", "False"], value="False", label="if save video")
                    with gr.Column(min_width=100):
                        frame_rate = gr.Number(label="Frame rate", value=5)
                with gr.Row():
                    with gr.Column(min_width=100):
                        button_video = gr.Button('Save video', variant='primary')

            with gr.Accordion('Drag'):

                with gr.Row():
                    with gr.Column(min_width=200):
                        reset_btn = gr.Button('Reset points and mask')
                with gr.Row():
                    with gr.Column(min_width=100):
                        button_drag = gr.Button('Drag it', variant='primary')
                    with gr.Column(min_width=100):
                        button_stop = gr.Button('Stop')

                progress = gr.Number(value=0, label='Steps', interactive=False)

        with gr.Column(scale=0.53):
            with gr.Tabs() as Tabs:
                image_show = to_image(img_show_original)
                image_clear = gr.State(image_show)
                mask = gr.State(np.zeros_like(image_clear.value))
                with gr.Tab('Setup Handle Points', id='input') as imagetab:
                    image = gr.Image(image_show).style(height=768, width=768)
                with gr.Tab('Draw a Mask', id='mask') as masktab:
                    mask_show = gr.ImageMask(image_show).style(height=768, width=768)

    # Event wiring: point selection, image upload/inversion, dragging, video
    # export, resets, and model switching.
    image.select(on_click, [image, target_point, points, res], [image, target_point]).then(on_show_save)

    image.upload(image_inversion, [image, res, model], [latent, image, image_clear, add_watermark]).then(
        reset_all, inputs=[image_clear, mask, add_watermark],
        outputs=[points, target_point, image, mask_show, mask, add_watermark])

    button_drag.click(on_drag,
                      inputs=[model, points, mask, max_step, latent, sample_interval, l_expected, d_max, if_save_video, add_watermark],
                      outputs=[image, progress, latent, image_clear, images_total, button_stop])
    button_stop.click(change_stop_state)

    button_video.click(save_video, inputs=[images_total, frame_rate], outputs=[images_total])
    reset_btn.click(reset_all, inputs=[image_clear, mask, add_watermark],
                    outputs=[points, target_point, image, mask_show, mask, add_watermark]).then(on_show_save)

    button_new.click(new_image, inputs=[model, seed], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask],
        outputs=[points, target_point, image, mask_show, mask, add_watermark])

    button_rand.click(new_image, inputs=[model], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask],
        outputs=[points, target_point, image, mask_show, mask, add_watermark])

    model_name.change(new_model, inputs=[model_name], outputs=[model, res, l_expected, d_max]).then(
        new_image, inputs=[model, seed], outputs=[image, image_clear, latent, seed]).then(
        reset_all, inputs=[image_clear, mask],
        outputs=[points, target_point, image, mask_show, mask, add_watermark])

    imagetab.select(update_mask, [image, mask_show], [image, mask])
    masktab.select(on_select_mask_tab, inputs=[image], outputs=[mask_show])


if __name__ == "__main__":

    # Queuing is required for generator-based (streaming) event handlers.
    demo.queue(concurrency_count=3, max_size=20).launch(share=True)