diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index f0708f18e799..006fa0a96857 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -1,3 +1,4 @@
+import inspect
 from typing import Optional, Tuple, Union
 
 import torch
@@ -59,6 +60,12 @@ def __call__(
 
         self.scheduler.set_timesteps(num_inference_steps)
 
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_kwargs = {}
+        if accepts_eta:
+            extra_kwargs["eta"] = eta
+
         for t in tqdm(self.scheduler.timesteps):
             if guidance_scale == 1.0:
                 # guidance_scale of 1 means no guidance
@@ -79,7 +86,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, eta=eta)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index fef9c5ac790f..3814827eea7f 100644
--- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -1,3 +1,5 @@
+import inspect
+
 import torch
 
 from tqdm.auto import tqdm
@@ -31,11 +33,17 @@ def __call__(
 
         self.scheduler.set_timesteps(num_inference_steps)
 
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_kwargs = {}
+        if accepts_eta:
+            extra_kwargs["eta"] = eta
+
         for t in tqdm(self.scheduler.timesteps):
             # predict the noise residual
             noise_prediction = self.unet(latents, t)["sample"]
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
+            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]
 
         # decode the image latents with the VAE
         image = self.vqvae.decode(latents)
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index f3a4bc3e0253..2ef59c438889 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -116,7 +116,6 @@ def step(
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-        **kwargs,
     ):
         if self.counter < len(self.prk_timesteps):
             return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
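
For reference, the check added in both pipelines is plain keyword-argument feature detection via `inspect.signature`. A minimal, self-contained sketch of the same pattern is below; the toy `step_with_eta`, `step_without_eta`, and `call_step` functions are hypothetical illustrations, not diffusers APIs.

```python
import inspect


def step_with_eta(sample, eta=0.0):
    """Toy stand-in for a DDIM-style scheduler step that accepts `eta`."""
    return sample * (1.0 - eta)


def step_without_eta(sample):
    """Toy stand-in for a PNDM-style scheduler step with no `eta` parameter."""
    return sample * 0.5


def call_step(step_fn, sample, eta=0.0):
    # Forward `eta` only when the step function actually declares it,
    # mirroring the capability check the diff adds to the pipelines.
    accepts_eta = "eta" in inspect.signature(step_fn).parameters
    extra_kwargs = {"eta": eta} if accepts_eta else {}
    return step_fn(sample, **extra_kwargs)


print(call_step(step_with_eta, 1.0, eta=0.1))     # 0.9
print(call_step(step_without_eta, 1.0, eta=0.1))  # 0.5 (eta silently dropped)
```

Because the caller now inspects the signature and drops `eta` for schedulers that do not declare it, the PNDM scheduler no longer needs the catch-all `**kwargs` in its `step` method, which is why the last hunk removes it.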