Skip to content

Commit 81f11fa

Browse files
fix last frame (#116)
1 parent 4b8fc15 commit 81f11fa

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

scripts/animatediff_latent.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
from modules import images, shared
34
from modules.devices import device, dtype_vae, torch_gc
@@ -25,13 +26,17 @@ def randomize(
2526
init_alpha[init_alpha < 0] = 0
2627

2728
if params.last_frame is not None:
29+
last_frame = params.last_frame
30+
if type(last_frame) == str:
31+
from modules.api.api import decode_base64_to_image
32+
last_frame = decode_base64_to_image(last_frame)
2833
# Get last_alpha
2934
last_alpha = [
3035
1 - pow(i, params.latent_power_last) / params.latent_scale_last
31-
for i in range(params.last_frame)
36+
for i in range(params.video_length)
3237
]
3338
last_alpha.reverse()
34-
logger.info(f"Randomizing last_latent according to {init_alpha}.")
39+
logger.info(f"Randomizing last_latent according to {last_alpha}.")
3540
last_alpha = torch.tensor(last_alpha, dtype=torch.float32, device=device)[
3641
:, None, None, None
3742
]
@@ -43,13 +48,18 @@ def randomize(
4348
scaling_factor = 1 / sum_alpha[mask_alpha]
4449
init_alpha[mask_alpha] *= scaling_factor
4550
last_alpha[mask_alpha] *= scaling_factor
51+
init_alpha[0] = 1
52+
init_alpha[-1] = 0
53+
last_alpha[0] = 0
54+
last_alpha[-1] = 1
4655

4756
# Calculate last_latent
48-
last_frame = params.last_frame
4957
if p.resize_mode != 3:
5058
last_frame = images.resize_image(
5159
p.resize_mode, last_frame, p.width, p.height
52-
)[None, ...]
60+
)
61+
last_frame = np.array(last_frame).astype(np.float32) / 255.0
62+
last_frame = np.moveaxis(last_frame, 2, 0)[None, ...]
5363
last_frame = torch.from_numpy(last_frame).to(device).to(dtype_vae)
5464
last_latent = images_tensor_to_samples(
5565
last_frame,
@@ -64,11 +74,10 @@ def randomize(
6474
size=(p.height // opt_f, p.width // opt_f),
6575
mode="bilinear",
6676
)
67-
6877
# Modify init_latent
6978
p.init_latent = (
7079
p.init_latent * init_alpha
71-
+ p.last_latent * last_alpha
80+
+ last_latent * last_alpha
7281
+ p.rng.next() * (1 - init_alpha - last_alpha)
7382
)
7483
else:

scripts/animatediff_ui.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ def refresh_models(*inputs):
164164
label="Optional latent scale for last frame",
165165
)
166166
self.params.last_frame = gr.Image(
167-
label="[Experiment] Optional last frame. Leave it blank if you do not need one."
167+
label="[Experiment] Optional last frame. Leave it blank if you do not need one.",
168+
type="pil",
168169
)
169170
with gr.Row():
170171
unload = gr.Button(

0 commit comments

Comments
 (0)