Skip to content

Commit

Permalink
[Dance Diffusion] Add dance diffusion (huggingface#803)
Browse files Browse the repository at this point in the history
* start

* add more logic

* Update src/diffusers/models/unet_2d_condition_flax.py

* match weights

* up

* make model work

* making class more general, fixing missed file rename

* small fix

* make new conversion work

* up

* finalize conversion

* up

* first batch of variable renamings

* remove c and c_prev var names

* add mid and out block structure

* add pipeline

* up

* finish conversion

* finish

* upload

* more fixes

* Apply suggestions from code review

* add attr

* up

* uP

* up

* finish tests

* finish

* uP

* finish

* fix test

* up

* naming consistency in tests

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* remove hardcoded 16

* Remove bogus

* fix some stuff

* finish

* improve logging

* docs

* upload

Co-authored-by: Nathan Lambert <nol@berkeley.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
  • Loading branch information
6 people authored Oct 25, 2022
1 parent fc89d4a commit 92dd118
Show file tree
Hide file tree
Showing 18 changed files with 917 additions and 12 deletions.
13 changes: 11 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
Expand All @@ -29,10 +29,19 @@
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .pipelines import (
DanceDiffusionPipeline,
DDIMPipeline,
DDPMPipeline,
KarrasVePipeline,
LDMPipeline,
PNDMPipeline,
ScoreSdeVePipeline,
)
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
Expand Down
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


if is_torch_available():
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
Expand Down
23 changes: 17 additions & 6 deletions models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,28 @@ def forward(self, timesteps):
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""

def __init__(self, embedding_size: int = 256, scale: float = 1.0):
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos

# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
if set_W_to_weight:
# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

self.weight = self.W
self.weight = self.W

def forward(self, x):
x = torch.log(x)
if self.log:
x = torch.log(x)

x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
172 changes: 172 additions & 0 deletions models/unet_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
Hidden states output. Output of last layer of model.
"""

sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
r"""
UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
"""

@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
mid_block_type: str = "UNetMidBlock1D",
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
block_out_channels: Tuple[int] = (32, 32, 64),
):
super().__init__()

self.sample_size = sample_size

# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]

if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.out_block = None

# down
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]

if i == 0:
input_channel += extra_in_channels

down_block = get_down_block(
down_block_type,
in_channels=input_channel,
out_channels=output_channel,
)
self.down_blocks.append(down_block)

# mid
self.mid_block = get_mid_block(
mid_block_type=mid_block_type,
mid_channels=block_out_channels[-1],
in_channels=block_out_channels[-1],
out_channels=None,
)

# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels

up_block = get_up_block(
up_block_type,
in_channels=prev_output_channel,
out_channels=output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel

# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
# Totally fine to add another layer with a if statement - no need for nn.Identity here

def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): `(batch_size, sample_size, num_channels)` noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if len(timestep.shape) == 0:
timestep = timestep[None]

timestep_embed = self.time_proj(timestep)[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])

# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples

# 3. mid
sample = self.mid_block(sample)

# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_samples)

if not return_dict:
return (sample,)

return UNet1DOutput(sample=sample)
Loading

0 comments on commit 92dd118

Please sign in to comment.