@@ -155,10 +155,11 @@ def ui(self, is_img2img):
155
155
])
156
156
model_dropdowns .append (model )
157
157
158
- def refresh_all_models (* dropdowns ):
158
+ def refresh_all_models (dropdowns ):
159
159
update_cn_models ()
160
160
updates = []
161
161
for dd in dropdowns :
162
+ dd = dd ["value" ] if isinstance (dd , dict ) else dd
162
163
if dd in cn_models :
163
164
selected = dd
164
165
else :
@@ -175,14 +176,16 @@ def refresh_all_models(*dropdowns):
175
176
def create_canvas (h , w ):
176
177
return np .zeros (shape = (h , w , 3 ), dtype = np .uint8 ) + 255
177
178
178
- canvas_width = gr .Slider (label = "Canvas Width" , minimum = 256 , maximum = 1024 , value = 512 , step = 1 )
179
- canvas_height = gr .Slider (label = "Canvas Height" , minimum = 256 , maximum = 1024 , value = 512 , step = 1 )
179
+ resize_mode = gr .Radio (choices = ["Scale to Fit" , "Just Resize" ], value = "Scale to Fit" , label = "Resize Mode" )
180
+ with gr .Row ():
181
+ canvas_width = gr .Slider (label = "Canvas Width" , minimum = 256 , maximum = 1024 , value = 512 , step = 64 )
182
+ canvas_height = gr .Slider (label = "Canvas Height" , minimum = 256 , maximum = 1024 , value = 512 , step = 64 )
180
183
create_button = gr .Button (label = "Start" , value = 'Open drawing canvas!' )
181
184
input_image = gr .Image (source = 'upload' , type = 'numpy' , tool = 'sketch' )
182
185
gr .Markdown (value = 'Change your brush width to make it thinner if you want to draw something.' )
183
186
184
187
create_button .click (fn = create_canvas , inputs = [canvas_height , canvas_width ], outputs = [input_image ])
185
- ctrls += (input_image , scribble_mode )
188
+ ctrls += (input_image , scribble_mode , resize_mode )
186
189
187
190
return ctrls
188
191
@@ -212,7 +215,7 @@ def restore_networks():
212
215
self .latest_network .restore (unet )
213
216
self .latest_network = None
214
217
215
- enabled , module , model , weight ,image , scribble_mode = args
218
+ enabled , module , model , weight ,image , scribble_mode , resize_mode = args
216
219
217
220
if not enabled :
218
221
restore_networks ()
@@ -262,8 +265,12 @@ def restore_networks():
262
265
263
266
control = torch .from_numpy (detected_map .copy ()).float ().cuda () / 255.0
264
267
control = rearrange (control , 'h w c -> c h w' )
265
- control = Resize (h if h > w else w , interpolation = InterpolationMode .BICUBIC )(control )
266
- control = CenterCrop ((h , w ))(control )
268
+
269
+ if resize_mode == "Scale to Fit" :
270
+ control = Resize (h if h > w else w , interpolation = InterpolationMode .BICUBIC )(control )
271
+ control = CenterCrop ((h , w ))(control )
272
+ else :
273
+ control = Resize ((h ,w ), interpolation = InterpolationMode .BICUBIC )(control )
267
274
268
275
self .control = control
269
276
# control = torch.stack([control for _ in range(bsz)], dim=0)
0 commit comments