[Flax] added broadcast_to_shape_from_left helper and Scheduler tests #864

Merged: 15 commits, Oct 25, 2022
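In short, this PR makes the Flax schedulers stateless: state is no longer built in `__init__` but created on demand via a new `has_state` property and `create_state()` method, `set_timesteps` gains a default `shape=()`, and the repeated rank-padding loops are factored into a shared `broadcast_to_shape_from_left` helper. A minimal sketch of the resulting calling pattern, assuming diffusers with the Flax extras installed and default scheduler configuration (names and numbers are illustrative, not taken from the PR):

```python
from diffusers import FlaxDDPMScheduler

# State is no longer created inside __init__; callers build it explicitly.
scheduler = FlaxDDPMScheduler()
state = scheduler.create_state()

# `shape` now defaults to (), so it can be omitted when it is not needed.
state = scheduler.set_timesteps(state, num_inference_steps=50)
```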
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/__init__.py
@@ -34,7 +34,7 @@
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
else:
from ..utils.dummy_flax_objects import * # noqa F403

15 changes: 6 additions & 9 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -173,7 +173,9 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):

return variance

def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -211,9 +213,6 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

Returns:
@@ -279,13 +278,11 @@ def add_noise(
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[:, None]
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
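The `add_noise` change above is representative of every scheduler in this PR: the rank-padding `while` loops are replaced by the shared helper. A small standalone check that the two spellings agree, assuming the new export from `diffusers.schedulers` added in this PR (array contents are made up):

```python
import jax.numpy as jnp

from diffusers.schedulers import broadcast_to_shape_from_left

original_samples = jnp.ones((2, 3, 8, 8))  # (batch, channels, height, width)
sqrt_alpha_prod = jnp.array([0.9, 0.5])    # one coefficient per sampled timestep

# Old pattern: append singleton axes one at a time until the ranks match.
old = sqrt_alpha_prod
while len(old.shape) < len(original_samples.shape):
    old = old[:, None]

# New pattern: pad and broadcast in a single call.
new = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

assert jnp.array_equal(old * original_samples, new * original_samples)
```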
23 changes: 13 additions & 10 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -129,11 +133,12 @@ def __init__(
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.one = jnp.array(1.0)

self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)

self.variance_type = variance_type
def create_state(self):
return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
def set_timesteps(
self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDPMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -214,7 +219,7 @@ def step(
"""
t = timestep

if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
else:
predicted_variance = None
@@ -267,13 +272,11 @@ def add_noise(
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
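The `has_state` property added here (and to the other schedulers below) gives callers a uniform way to ask whether a scheduler carries an explicit state object. A hypothetical helper showing how the flag and `create_state` compose; the function is mine for illustration, not part of the PR:

```python
from diffusers import FlaxDDPMScheduler, FlaxLMSDiscreteScheduler

def init_scheduler_state(scheduler):
    """Return a freshly created state for stateful Flax schedulers, else None."""
    if getattr(scheduler, "has_state", False):
        return scheduler.create_state()
    return None

ddpm_state = init_scheduler_state(FlaxDDPMScheduler())        # DDPMSchedulerState
lms_state = init_scheduler_state(FlaxLMSDiscreteScheduler())  # LMSDiscreteSchedulerState
```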
11 changes: 9 additions & 2 deletions src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
A reasonable range is [0.2, 80].
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -97,10 +101,13 @@ def __init__(
s_min: float = 0.05,
s_max: float = 50,
):
self.state = KarrasVeSchedulerState.create()
pass

def create_state(self):
return KarrasVeSchedulerState.create()

def set_timesteps(
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> KarrasVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
15 changes: 10 additions & 5 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


@flax.struct.dataclass
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -85,8 +89,10 @@ def __init__(
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

def create_state(self):
self.state = LMSDiscreteSchedulerState.create(
num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
num_train_timesteps=self.config.num_train_timesteps,
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
)

def get_lms_coefficient(self, state, order, t, current_order):
@@ -112,7 +118,7 @@ def lms_derivative(tau):
return integrated_coeff

def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -199,8 +205,7 @@ def add_noise(
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigma = state.sigmas[timesteps].flatten()
while len(sigma.shape) < len(noise.shape):
sigma = sigma[..., None]
sigma = broadcast_to_shape_from_left(sigma, noise.shape)

noisy_samples = original_samples + noise * sigma
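One detail underpinning the `create_state` changes in this file and the others: because the constructors are decorated with `@register_to_config`, the original init arguments stay reachable as `self.config.<name>`, so `create_state` can rebuild things like the sigma schedule later without re-passing them. A deliberately simplified illustration of that idea; the real mechanism is diffusers' `ConfigMixin`/`register_to_config`, not this toy decorator:

```python
from types import SimpleNamespace

def register_to_config(init):
    # Toy stand-in: stash the init kwargs on `self.config` before running __init__.
    def wrapper(self, *args, **kwargs):
        self.config = SimpleNamespace(**kwargs)
        return init(self, *args, **kwargs)
    return wrapper

class TinyScheduler:
    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000):
        pass  # no state built here anymore

    def create_state(self):
        # Config values remain available long after __init__ has returned.
        return {"timesteps": list(range(self.config.num_train_timesteps))[::-1]}

print(TinyScheduler(num_train_timesteps=5).create_state())  # {'timesteps': [4, 3, 2, 1, 0]}
```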

10 changes: 5 additions & 5 deletions src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,7 +23,7 @@
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -168,6 +168,8 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
the `FlaxPNDMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
shape (`Tuple`):
the shape of the samples to be generated.
"""
offset = self.config.steps_offset

@@ -509,13 +511,11 @@ def add_noise(
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
26 changes: 18 additions & 8 deletions src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -22,7 +22,7 @@
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


@flax.struct.dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
correct_steps (`int`): number of correction steps performed on a produced sample.
"""

@property
def has_state(self):
return True

@register_to_config
def __init__(
self,
@@ -90,12 +94,20 @@ def __init__(
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
state = ScoreSdeVeSchedulerState.create()
pass

self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def create_state(self):
state = ScoreSdeVeSchedulerState.create()
return self.set_sigmas(
state,
self.config.num_train_timesteps,
self.config.sigma_min,
self.config.sigma_max,
self.config.sampling_eps,
)

def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -193,8 +205,7 @@ def step_pred(
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
diffusion = diffusion.flatten()
while len(diffusion.shape) < len(sample.shape):
diffusion = diffusion[:, None]
diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
drift = drift - diffusion**2 * model_output

# equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ def step_correct(

# compute corrected sample: model_output term and noise term
step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
step_size = step_size[:, None]
step_size = broadcast_to_shape_from_left(step_size, sample.shape)
prev_sample_mean = sample + step_size * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

6 changes: 6 additions & 0 deletions src/diffusers/schedulers/scheduling_utils_flax.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp

@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
"""

config_name = SCHEDULER_CONFIG_NAME


def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
assert len(shape) >= x.ndim
return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
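For reference, the helper on its own: trailing singleton axes are appended (hence "from the left": the input indexes the leading dimensions) and the result is broadcast to the requested shape. A short standalone check with made-up values:

```python
import jax.numpy as jnp

from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

x = jnp.array([0.5, 2.0])                       # shape (2,), e.g. one value per batch item
y = broadcast_to_shape_from_left(x, (2, 3, 4))  # shape (2, 3, 4)

assert y.shape == (2, 3, 4)
assert (y[0] == 0.5).all() and (y[1] == 2.0).all()
```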