It seems the entire UNet is fine-tuned for this task, which roughly aligns with the image-conditioning setup:
```python
import math
import random

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import rearrange, repeat
# helper shipped with diffusers' SVD pipeline module
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing


def finetune_unet(accelerator, pipeline, batch, use_offset_noise,
                  rescale_schedule, offset_noise_strength, unet, motion_mask,
                  P_mean=0.7, P_std=1.6):
    pipeline.vae.eval()
    pipeline.image_encoder.eval()
    device = unet.device
    dtype = pipeline.vae.dtype
    vae = pipeline.vae

    # Convert videos to latent space
    pixel_values = batch['pixel_values']
    bsz, num_frames = pixel_values.shape[:2]
    frames = rearrange(pixel_values, 'b f c h w -> (b f) c h w').to(dtype)
    latents = vae.encode(frames).latent_dist.mode() * vae.config.scaling_factor
    latents = rearrange(latents, '(b f) c h w -> b f c h w', b=bsz)

    # Encode the first frame as the image latent, with log-normal noise augmentation
    image = pixel_values[:, 0].to(dtype)
    noise_aug_strength = math.exp(random.normalvariate(mu=-3, sigma=0.5))
    image = image + noise_aug_strength * torch.randn_like(image)
    image_latent = vae.encode(image).latent_dist.mode() * vae.config.scaling_factor

    if motion_mask:
        # Binarize the motion mask at latent resolution; frame 0 stays frozen
        mask = batch['mask'].div(255)
        h, w = latents.shape[-2:]
        mask = T.Resize((h, w), antialias=False)(mask)
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = repeat(mask, 'b h w -> b f 1 h w', f=num_frames).detach().clone()
        mask[:, 0] = 0
        freeze = repeat(image_latent, 'b c h w -> b f c h w', f=num_frames)
        condition_latent = latents * (1 - mask) + freeze * mask
    else:
        condition_latent = repeat(image_latent, 'b c h w -> b f c h w', f=num_frames)

    # CLIP image embeddings for cross-attention conditioning
    pipeline.image_encoder.to(device, dtype=dtype)
    images = _resize_with_antialiasing(pixel_values[:, 0], (224, 224)).to(dtype)
    images = (images + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    images = pipeline.feature_extractor(
        images=images,
        do_normalize=True,
        do_center_crop=False,
        do_resize=False,
        do_rescale=False,
        return_tensors="pt",
    ).pixel_values
    image_embeddings = pipeline._encode_image(images, device, 1, False)
    encoder_hidden_states = image_embeddings
    uncond_hidden_states = torch.zeros_like(image_embeddings)
    if random.random() < 0.15:  # drop conditioning for classifier-free guidance
        encoder_hidden_states = uncond_hidden_states

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process); EDM preconditioning, [bsz, f, c, h, w]
    rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
    sigma = (rnd_normal * P_std + P_mean).exp()
    c_skip = 1 / (sigma**2 + 1)
    c_out = -sigma / (sigma**2 + 1) ** 0.5
    c_in = 1 / (sigma**2 + 1) ** 0.5
    c_noise = (sigma.log() / 4).reshape([bsz])
    loss_weight = (sigma**2 + 1) / sigma**2
    noisy_latents = latents + torch.randn_like(latents) * sigma
    input_latents = torch.cat([c_in * noisy_latents,
                               condition_latent / vae.config.scaling_factor], dim=2)
    if motion_mask:
        input_latents = torch.cat([mask, input_latents], dim=2)

    motion_bucket_id = 127
    fps = 7
    added_time_ids = pipeline._get_add_time_ids(fps, motion_bucket_id,
        noise_aug_strength, image_embeddings.dtype, bsz, 1, False)
    added_time_ids = added_time_ids.to(device)

    loss = 0
    accelerator.wait_for_everyone()
    model_pred = unet(input_latents, c_noise, encoder_hidden_states=encoder_hidden_states,
                      added_time_ids=added_time_ids).sample
    predict_x0 = c_out * model_pred + c_skip * noisy_latents
    loss += ((predict_x0 - latents) ** 2 * loss_weight).mean()
    if motion_mask:
        # keep static (unmasked) regions pinned to the condition latent
        loss += F.mse_loss(predict_x0 * (1 - mask), condition_latent * (1 - mask))
    return loss
```
https://github.com/alibaba/animate-anything/blob/main/train_svd.py
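
For reference, the `P_mean`/`P_std` sampling and the `c_skip`/`c_out`/`c_in`/`c_noise` scalars above are EDM-style preconditioning (Karras et al. 2022) with $\sigma_{\text{data}} = 1$; written out from the code:

$$\sigma = \exp(P_{\text{mean}} + P_{\text{std}} \cdot n), \quad n \sim \mathcal{N}(0, 1)$$

$$c_{\text{skip}} = \frac{1}{\sigma^2 + 1}, \quad c_{\text{out}} = \frac{-\sigma}{\sqrt{\sigma^2 + 1}}, \quad c_{\text{in}} = \frac{1}{\sqrt{\sigma^2 + 1}}, \quad c_{\text{noise}} = \frac{\log \sigma}{4}$$

$$\hat{x}_0 = c_{\text{out}} \cdot F_\theta(c_{\text{in}} \cdot x_\sigma;\ c_{\text{noise}}) + c_{\text{skip}} \cdot x_\sigma, \qquad \mathcal{L} = \mathbb{E}\left[\frac{\sigma^2 + 1}{\sigma^2} \,\lVert \hat{x}_0 - x_0 \rVert^2\right]$$

where $x_\sigma = x_0 + \sigma \varepsilon$ is the noised latent, so the network is effectively trained to predict the clean latent $x_0$ directly.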
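For context, here is a minimal sketch of how `finetune_unet` would slot into an accelerate-style training step; `train_dataloader`, `optimizer`, and `lr_scheduler` are assumed names (prepared via `accelerator.prepare`), not necessarily the repo's exact loop:

```python
# Hypothetical training step wiring finetune_unet into an accelerate loop
for batch in train_dataloader:
    with accelerator.accumulate(unet):
        loss = finetune_unet(accelerator, pipeline, batch,
                             use_offset_noise=False, rescale_schedule=False,
                             offset_noise_strength=0.0, unet=unet,
                             motion_mask=True)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```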
UPDATE - maybe we can just use Alibaba's pretrained 6.8 GB model.
https://github.com/alibaba/animate-anything?tab=readme-ov-file
Just need to wire it into `MaskStableVideoDiffusionPipeline`, whose signature already defaults to `motion_bucket_id: int = 127`.
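A rough sketch of that wiring, assuming the class follows the standard `diffusers` `from_pretrained` interface; the import path, checkpoint path, and `__call__` kwargs below are assumptions to verify against the repo:

```python
import torch
from diffusers.utils import load_image
# import path is an assumption; the class is defined inside the animate-anything repo
from models.pipeline import MaskStableVideoDiffusionPipeline

# Load Alibaba's pretrained checkpoint (the ~6.8 GB download linked above)
pipeline = MaskStableVideoDiffusionPipeline.from_pretrained(
    "checkpoints/animate_anything_svd_v1.0",  # hypothetical local path
    torch_dtype=torch.float16,
).to("cuda")

image = load_image("input.png")
mask = load_image("motion_mask.png")  # white = regions allowed to move

frames = pipeline(
    image,
    mask=mask,             # assumed kwarg name; verify against the pipeline
    motion_bucket_id=127,  # the default noted above
    fps=7,
).frames
```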