diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 58d814796ede..6af11e887fb6 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -121,6 +121,10 @@ from .artist.tokenizer import * from .dallebart.modeling import * from .dallebart.tokenizer import * +from .clip.modeling import * +from .clip.feature_extraction import * +from .clip.tokenizer import * +from .clip.procesing import * from .gptj.modeling import * from .gptj.tokenizer import * diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index 706074d91be5..97aa79a20484 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -88,6 +88,7 @@ ("Bart", "bart"), ("GAUAlpha", "gau_alpha"), ("CodeGen", "codegen"), + ("CLIP", "clip"), ("Artist", "artist"), ("OPT", 'opt') ]) diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py index a6d056bbcb4f..4b759ea8b659 100644 --- a/paddlenlp/transformers/auto/tokenizer.py +++ b/paddlenlp/transformers/auto/tokenizer.py @@ -79,6 +79,7 @@ ("BartTokenizer", "bart"), ("GAUAlphaTokenizer", "gau_alpha"), ("CodeGenTokenizer", "codegen"), + ("CLIPTokenizer", "clip"), ("ArtistTokenizer", "artist"), ]) diff --git a/paddlenlp/transformers/clip/__init__.py b/paddlenlp/transformers/clip/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/paddlenlp/transformers/clip/feature_extraction.py b/paddlenlp/transformers/clip/feature_extraction.py new file mode 100644 index 000000000000..d1c9e68646ea --- /dev/null +++ b/paddlenlp/transformers/clip/feature_extraction.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for CLIP.""" + +from typing import List, Optional, Union + +import paddle +import numpy as np +from PIL import Image + +from ..feature_extraction_utils import BatchFeature +from ..tokenizer_utils_base import TensorType +from ..image_utils import ImageFeatureExtractionMixin + +__all__ = ["CLIPFeatureExtractor"] + + +class CLIPFeatureExtractor(ImageFeatureExtractionMixin): + r""" + Constructs a CLIP feature extractor. + This feature extractor inherits from [`ImageFeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int`, *optional*, defaults to 224): + Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. 
Only has an effect + if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the + image is padded with 0's and then center cropped. + crop_size (`int`, *optional*, defaults to 224): + Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with `image_mean` and `image_std`. + image_mean (`List[int]`, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (`List[int]`, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + convert_rgb (`bool`, defaults to `True`): + Whether or not to convert `PIL.Image.Image` into `RGB` format + """ + + model_input_names = ["pixel_values"] + + def __init__(self, + do_resize=True, + size=224, + resample=Image.BICUBIC, + do_center_crop=True, + crop_size=224, + do_normalize=True, + image_mean=None, + image_std=None, + do_convert_rgb=True, + **kwargs): + super().__init__() + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [ + 0.48145466, 0.4578275, 0.40821073 + ] + self.image_std = image_std if image_std is not None else [ + 0.26862954, 0.26130258, 0.27577711 + ] + self.do_convert_rgb = do_convert_rgb + + def __call__( + self, + images: Union[Image.Image, np.ndarray, "paddle.Tensor", + List[Image.Image], List[np.ndarray], + List["paddle.Tensor"] # noqa + ], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs): + """ + Main method to prepare for the model one or several image(s). + + NumPy arrays and Paddle tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle + tensor. In case of a NumPy array/Paddle tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + - `'pd'`: Return Paddle `paddle.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **pixel_values** -- Pixel values to be fed to a model. + """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, + (Image.Image, np.ndarray)) or paddle.is_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance( + images[0], + (Image.Image, np.ndarray)) or paddle.is_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `paddle.Tensor` (single example), " + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[paddle.Tensor]` (batch of examples)." 
+ ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) + or paddle.is_tensor(images[0]))) + + if not is_batched: + images = [images] + + # transformations (convert rgb + resizing + center cropping + normalization) + if self.do_convert_rgb: + images = [self.convert_rgb(image) for image in images] + if self.do_resize and self.size is not None and self.resample is not None: + images = [ + self.resize(image=image, + size=self.size, + resample=self.resample, + default_to_square=False) for image in images + ] + if self.do_center_crop and self.crop_size is not None: + images = [ + self.center_crop(image, self.crop_size) for image in images + ] + if self.do_normalize: + images = [ + self.normalize(image=image, + mean=self.image_mean, + std=self.image_std) for image in images + ] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/paddlenlp/transformers/clip/modeling.py b/paddlenlp/transformers/clip/modeling.py new file mode 100644 index 000000000000..95081f0a315d --- /dev/null +++ b/paddlenlp/transformers/clip/modeling.py @@ -0,0 +1,1891 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple, Optional, Union + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from dataclasses import dataclass +from ..model_outputs import BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput +from .. import PretrainedModel, register_base_model +from ..stable_diffusion_utils import StableDiffusionMixin, AutoencoderKL, PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, UNet2DConditionModel +from ..guided_diffusion_utils import DiscoDiffusionMixin, create_gaussian_diffusion, create_unet_model, create_secondary_model + +__all__ = [ + 'VisionTransformer', + 'TextTransformer', + 'CLIPTextModel', + 'CLIPVisionModel', + 'CLIPPretrainedModel', + 'CLIPModel', + 'CLIPForImageGeneration', + 'ModifiedResNet', +] + + +@dataclass +class CLIPOutput(ModelOutput): + """ + Args: + loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`paddle.Tensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`paddle.Tensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`paddle.Tensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`TextTransformer`]. 
+ image_embeds(`paddle.Tensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`VisionTransformer`]. + text_model_output(:class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`): + The output of the [`TextTransformer`]. + vision_model_output(:class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`): + The output of the [`VisionTransformer`]. + """ + + loss: Optional[paddle.Tensor] = None + logits_per_image: paddle.Tensor = None + logits_per_text: paddle.Tensor = None + text_embeds: paddle.Tensor = None + image_embeds: paddle.Tensor = None + text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + vision_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output" + ] else getattr(self, k).to_tuple() + for k in self.keys()) + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +def contrastive_loss(logits): + return F.cross_entropy(logits, paddle.arange(len(logits))) + + +def clip_loss(similarity): + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +class ModifiedResNet(nn.Layer): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2D(3, + width // 2, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False) + self.bn1 = nn.BatchNorm2D(width // 2) + self.conv2 = nn.Conv2D(width // 2, + width // 2, + kernel_size=3, + padding=1, + bias_attr=False) + self.bn2 = nn.BatchNorm2D(width // 2) + self.conv3 = nn.Conv2D(width // 2, + width, + kernel_size=3, + padding=1, + bias_attr=False) + self.bn3 = nn.BatchNorm2D(width) + self.avgpool = nn.AvgPool2D(2) + self.relu = nn.ReLU() + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = stem(x) + x = self.layer1(x) + x = 
self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +def multi_head_attention_forward(x: paddle.Tensor, + num_heads: int, + q_proj: nn.Linear, + k_proj: nn.Linear, + v_proj: nn.Linear, + c_proj: nn.Linear, + attn_mask: Optional[paddle.Tensor] = None): + max_len, batch_size, emb_dim = x.shape + head_dim = emb_dim // num_heads + scaling = float(head_dim)**-0.5 + q = q_proj(x) # L, N, E + k = k_proj(x) # L, N, E + v = v_proj(x) # L, N, E + + v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2)) + k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2)) + q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2)) + + q = q * scaling + qk = paddle.matmul(q, k, transpose_y=True) + if attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask.unsqueeze_(0) + assert attn_mask.shape[0] == 1 and attn_mask.shape[ + 1] == max_len and attn_mask.shape[2] == max_len + qk += attn_mask + + qk = F.softmax(qk, axis=-1) + atten = paddle.bmm(qk, v) + atten = atten.transpose((1, 0, 2)) + atten = atten.reshape((max_len, batch_size, emb_dim)) + atten = c_proj(atten) + return atten + + +class Identity(nn.Layer): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +class Bottleneck(nn.Layer): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False) + self.bn1 = nn.BatchNorm2D(planes) + + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(planes) + + self.avgpool = nn.AvgPool2D(stride) if stride > 1 else Identity() + + self.conv3 = nn.Conv2D(planes, + planes * self.expansion, + 1, + bias_attr=False) + self.bn3 = nn.BatchNorm2D(planes * self.expansion) + + self.relu = nn.ReLU() + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + ("-1", nn.AvgPool2D(stride)), + ("0", + nn.Conv2D(inplanes, + planes * self.expansion, + 1, + stride=1, + bias_attr=False)), + ("1", nn.BatchNorm2D(planes * self.expansion))) + + def forward(self, x): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Layer): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + + self.positional_embedding = nn.Embedding(spacial_dim**2 + 1, embed_dim) + + self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True) + self.c_proj = nn.Linear(embed_dim, + output_dim or embed_dim, + bias_attr=True) + self.num_heads = num_heads + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + + def forward(self, x): + + x = x.reshape( + (x.shape[0], x.shape[1], x.shape[2] * x.shape[3])).transpose( + (2, 0, 1)) # NCHW -> (HW)NC + x = paddle.concat([x.mean(axis=0, keepdim=True), x], axis=0) + x = x + paddle.unsqueeze(self.positional_embedding.weight, 1) + out = 
multi_head_attention_forward(x, self.num_heads, self.q_proj, + self.k_proj, self.v_proj, + self.c_proj) + + return out[0] + + +class CLIPPretrainedModel(PretrainedModel): + """ + An abstract class for pretrained CLIP models. It provides CLIP related + `model_config_file`, `pretrained_init_configuration`, `resource_files_names`, + `pretrained_resource_files_map`, `base_model_prefix` for downloading and + loading pretrained models. + See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. + """ + model_config_file = "model_config.json" + pretrained_init_configuration = { + "openai/clip-vit-base-patch32": { + # vision + "image_resolution": 224, + "vision_layers": 12, + "vision_heads": 12, + "vision_mlp_ratio": 4, + "vision_embed_dim": 768, + "vision_patch_size": 32, + "vision_hidden_act": "quick_gelu", + # text + "max_text_length": 77, + "vocab_size": 49408, + "text_embed_dim": 512, + "text_heads": 8, + "text_layers": 12, + "text_hidden_act": "quick_gelu", + # others + "projection_dim": 512, + "initializer_range": 0.02, + "logit_scale_init_value": 2.6592 + }, + "openai/clip-rn50": { + # vision + "image_resolution": 224, + "vision_layers": [3, 4, 6, 3], + "vision_heads": 32, + "vision_mlp_ratio": None, # do not use + "vision_embed_dim": 64, # vision width + "vision_patch_size": None, # do not use + "vision_hidden_act": None, # do not use + # text + "max_text_length": 77, + "vocab_size": 49408, + "text_embed_dim": 512, + "text_heads": 8, + "text_layers": 12, + "text_hidden_act": "quick_gelu", + # others + "projection_dim": 1024, + "initializer_range": 0.02, + "logit_scale_init_value": 2.6592 + }, + "openai/clip-rn101": { + # vision + "image_resolution": 224, + "vision_layers": [3, 4, 23, 3], + "vision_heads": 32, + "vision_mlp_ratio": None, # do not use + "vision_embed_dim": 64, # vision width + "vision_patch_size": None, # do not use + "vision_hidden_act": None, # do not use + # text + "max_text_length": 77, + "vocab_size": 49408, + "text_embed_dim": 512, + "text_heads": 8, + "text_layers": 12, + "text_hidden_act": "quick_gelu", + # others + "projection_dim": 512, + "initializer_range": 0.02, + "logit_scale_init_value": 2.6592 + }, + "openai/clip-vit-large-patch14": { + # vision + "image_resolution": 224, + "vision_layers": 24, + "vision_heads": 16, + "vision_mlp_ratio": 4, + "vision_embed_dim": 1024, + "vision_patch_size": 14, + "vision_hidden_act": "quick_gelu", + # text + "max_text_length": 77, + "vocab_size": 49408, + "text_embed_dim": 768, + "text_heads": 12, + "text_layers": 12, + "text_hidden_act": "quick_gelu", + # others + "projection_dim": 768, + "initializer_range": 0.02, + "logit_scale_init_value": 2.6592 + }, + } + resource_files_names = {"model_state": "model_state.pdparams"} + pretrained_resource_files_map = { + "model_state": { + "openai/clip-vit-base-patch32": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-base-patch32/model_state.pdparams", + "openai/clip-rn50": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn50/model_state.pdparams", + "openai/clip-rn101": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn101/model_state.pdparams", + "openai/clip-vit-large-patch14": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-large-patch14/model_state.pdparams", + } + } + base_model_prefix = "clip" + + def _init_weights(self, layer): + """Initialize the weights""" + initializer_range = self.initializer_range if hasattr( + self, + "initializer_range") else 
self.clip.config["initializer_range"] + factor = self.initializer_factor if hasattr( + self, + "initializer_factor") else self.clip.config["initializer_factor"] + + if isinstance(layer, VisionTransformer): + vision_embed_dim = self.vision_embed_dim if hasattr( + self, + "vision_embed_dim") else self.clip.config["vision_embed_dim"] + vision_layers = self.vision_layers if hasattr( + self, "vision_layers") else self.clip.config["vision_layers"] + # vision embedding + layer.class_embedding.set_value( + paddle.normal( + std=vision_embed_dim**-0.5 * factor, + shape=layer.class_embedding.shape, + )) + layer.conv1.weight.set_value( + paddle.normal( + std=initializer_range * factor, + shape=layer.conv1.weight.shape, + )) + layer.positional_embedding.weight.set_value( + paddle.normal( + std=initializer_range * factor, + shape=layer.positional_embedding.weight.shape, + )) + + elif isinstance(layer, TextTransformer): + text_embed_dim = self.text_embed_dim if hasattr( + self, "text_embed_dim") else self.clip.config["text_embed_dim"] + text_layers = self.text_layers if hasattr( + self, "text_layers") else self.clip.config["text_layers"] + # text embedding + layer.token_embedding.weight.set_value( + paddle.normal( + mean=0.0, + std=factor * 0.02, + shape=layer.token_embedding.weight.shape, + )) + layer.positional_embedding.weight.set_value( + paddle.normal( + mean=0.0, + std=factor * 0.02, + shape=layer.positional_embedding.weight.shape, + )) + elif isinstance(layer, CLIPModel): + vision_embed_dim = self.vision_embed_dim if hasattr( + self, + "vision_embed_dim") else self.clip.config["vision_embed_dim"] + vision_layers = self.vision_layers if hasattr( + self, "vision_layers") else self.clip.config["vision_layers"] + text_embed_dim = self.text_embed_dim if hasattr( + self, "text_embed_dim") else self.clip.config["text_embed_dim"] + text_layers = self.text_layers if hasattr( + self, "text_layers") else self.clip.config["text_layers"] + layer.text_projection.set_value( + paddle.normal( + std=text_embed_dim**-0.5 * factor, + shape=layer.text_projection.shape, + )) + if hasattr(layer, "vision_projection"): + layer.vision_projection.set_value( + paddle.normal( + std=vision_embed_dim**-0.5 * factor, + shape=layer.vision_projection.shape, + )) + for name, sub_layer in layer.named_sublayers(): + num_layers = vision_layers if "vision_model" in name else text_layers + if isinstance(sub_layer, nn.TransformerEncoderLayer): + # self_attn + in_proj_std = (sub_layer.self_attn.embed_dim**-0.5) * ( + (2 * num_layers)**-0.5) * factor + out_proj_std = (sub_layer.self_attn.embed_dim** + -0.5) * factor + sub_layer.self_attn.q_proj.weight.set_value( + paddle.normal( + std=in_proj_std, + shape=sub_layer.self_attn.q_proj.weight.shape, + )) + sub_layer.self_attn.k_proj.weight.set_value( + paddle.normal( + std=in_proj_std, + shape=sub_layer.self_attn.k_proj.weight.shape, + )) + sub_layer.self_attn.v_proj.weight.set_value( + paddle.normal( + std=in_proj_std, + shape=sub_layer.self_attn.v_proj.weight.shape, + )) + sub_layer.self_attn.out_proj.weight.set_value( + paddle.normal( + std=out_proj_std, + shape=sub_layer.self_attn.out_proj.weight.shape, + )) + # ffn + in_proj_std = ((sub_layer._config["d_model"]**-0.5) * + ((2 * num_layers)**-0.5) * factor) + fc_std = (2 * sub_layer._config["d_model"])**-0.5 * factor + sub_layer.linear1.weight.set_value( + paddle.normal( + std=fc_std, + shape=sub_layer.linear1.weight.shape, + )) + sub_layer.linear2.weight.set_value( + paddle.normal( + std=in_proj_std, + 
shape=sub_layer.linear2.weight.shape, + )) + if isinstance(layer, nn.LayerNorm): + layer.bias.set_value(paddle.zeros_like(layer.bias)) + layer.weight.set_value(paddle.ones_like(layer.weight)) + if isinstance(layer, nn.Linear) and layer.bias is not None: + layer.bias.set_value(paddle.zeros_like(layer.bias)) + + +# set attr +def quick_gelu(x): + return x * F.sigmoid(1.702 * x) + + +F.quick_gelu = quick_gelu + +NEG_INF = float("-inf") # -1e4 -1e9 + + +class VisionTransformer(nn.Layer): + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + activation: str, + mlp_ratio: int, + normalize_before: bool = True): + super().__init__() + self.input_resolution = input_resolution + # used patch_size x patch_size, stride patch_size to do linear projection + self.conv1 = nn.Conv2D(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias_attr=False) + + self.class_embedding = paddle.create_parameter( + (width, ), paddle.get_default_dtype()) + + self.positional_embedding = nn.Embedding( + (input_resolution // patch_size)**2 + 1, width) + + self.ln_pre = nn.LayerNorm(width) + encoder_layer = nn.TransformerEncoderLayer( + d_model=width, + nhead=heads, + dim_feedforward=width * mlp_ratio, + normalize_before=normalize_before, + dropout=0, + activation=activation, + attn_dropout=0, + act_dropout=0) + self.transformer = nn.TransformerEncoder(encoder_layer, layers) + + self.ln_post = nn.LayerNorm(width) + + def forward(self, + pixel_values, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + pixel_values = self.conv1(pixel_values) + pixel_values = pixel_values.reshape( + (pixel_values.shape[0], pixel_values.shape[1], -1)) + pixel_values = pixel_values.transpose((0, 2, 1)) + embedding_output = paddle.concat([ + self.class_embedding.unsqueeze([0, 1]).expand( + [pixel_values.shape[0], -1, -1]), pixel_values + ], + axis=1) + embedding_output = embedding_output + self.positional_embedding.weight + embedding_output = self.ln_pre(embedding_output) + encoder_outputs = self.transformer( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + if isinstance(encoder_outputs, type(embedding_output)): + last_hidden_state = encoder_outputs + else: + last_hidden_state = encoder_outputs[0] + + pooled_output = self.ln_post(last_hidden_state[:, 0]) + + if isinstance(encoder_outputs, type(embedding_output)): + return (last_hidden_state, pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions) + + +class TextTransformer(nn.Layer): + + def __init__(self, + context_length, + transformer_width, + transformer_heads, + transformer_layers, + vocab_size, + activation="quick_gelu", + normalize_before=True): + super().__init__() + encoder_layer = nn.TransformerEncoderLayer( + d_model=transformer_width, + nhead=transformer_heads, + dim_feedforward=transformer_width * 4, + normalize_before=normalize_before, + dropout=0, + activation=activation, + attn_dropout=0, + act_dropout=0) + self.transformer = nn.TransformerEncoder(encoder_layer, + transformer_layers) + + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Embedding(context_length, + 
transformer_width) + self.ln_final = nn.LayerNorm(transformer_width) + + self.register_buffer("causal_mask", + paddle.triu(paddle.ones( + (1, 1, context_length, context_length)) * + NEG_INF, + diagonal=1), + persistable=False) + self.register_buffer("position_ids", + paddle.arange(context_length).reshape((1, -1)), + persistable=False) + + def forward(self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + bs, seqlen = input_ids.shape[:2] + if position_ids is None: + position_ids = self.position_ids[:, :seqlen].astype("int64") + + causal_mask = self.causal_mask[:, :, :seqlen, :seqlen] + if attention_mask is not None: + assert attention_mask.ndim == 2 + expanded_mask = attention_mask[:, None, None, :].expand( + [bs, 1, seqlen, -1]).astype(causal_mask.dtype) + inverted_mask = (1.0 - expanded_mask) * NEG_INF + attention_mask = inverted_mask + causal_mask + else: + attention_mask = causal_mask + attention_mask.stop_gradient = True + + embedding_output = self.token_embedding( + input_ids) + self.positional_embedding( + position_ids) # [batch_size, n_ctx, d_model] + + encoder_outputs = self.transformer( + embedding_output, + src_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + if isinstance(encoder_outputs, type(embedding_output)): + last_hidden_state = encoder_outputs + else: + last_hidden_state = encoder_outputs[0] + + last_hidden_state = self.ln_final(last_hidden_state) + pooled_output = last_hidden_state.gather_nd( + paddle.stack( + [paddle.arange(bs), input_ids.argmax(-1)], axis=-1)) + + if isinstance(encoder_outputs, type(embedding_output)): + return (last_hidden_state, pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions) + + +@register_base_model +class CLIPModel(CLIPPretrainedModel): + r""" + The bare CLIP Model outputting logits_per_image and logits_per_text. + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + image_resolution (int, optional): + The size (resolution) of each image. + Defaults to `224`. + vision_layers (int, optional): + Number of hidden layers in the vision model. + Defaults to `12`. + vision_heads (int, optional): + Number of attention heads for each attention layer in the vision attention. + Defaults to `12`. + vision_embed_dim (int, optional): + Dimensionality of the embedding layer and encoder layers in vision model. + Defaults to `768`. + vision_patch_size(int, optional): + The size (resolution) of each patch. + Defaults to `32`. + vision_mlp_ratio(int, optional): + The ratio between dim_feedforward and vision_hidden_dim. `radio = dim_feedforward/vision_hidden_dim` + Defaults to `4`. + vision_hidden_act (str, optional): + The non-linear activation function of the ffn layer in the vision model. + ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported. + Defaults to `"quick_gelu"`. 
+ max_text_length (int, optional): + The maximum value of the dimensionality of text position encoding, which dictates the maximum supported length of the text + input sequence. Defaults to `64`. + vocab_size (int, optional): + Vocabulary size of `inputs_ids` in `CLIPModel`. Also is the vocab size of text token embedding matrix. + Defaults to `49408`. + text_embed_dim (int, optional): + Dimensionality of the embedding layer and encoder layers in text model. + Defaults to `768`. + text_heads (int, optional): + Number of attention heads for each attention layer in the text attention. + Defaults to `8`. + text_layers (int, optional): + Number of hidden layers in the text model. + Defaults to `12`. + text_hidden_act (str, optional): + The non-linear activation function of the ffn layer in the text model. + ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported. + Defaults to `"quick_gelu"`. + projection_dim (int, optional): + Dimentionality of text and vision projection layers. + Defaults to `512`. + initializer_range (float, optional): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Default to `0.02`. + initializer_factor (float, optional): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). Default to `1.`. + logit_scale_init_value (float, optional): + The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. + Default to `2.6592`. + + """ + + def __init__( + self, + # vision + image_resolution: int = 224, + vision_layers: Union[Tuple[int, int, int, int], int] = 12, + vision_heads: int = 12, + vision_embed_dim: int = 768, + vision_patch_size: int = 32, + vision_mlp_ratio: int = 4, + vision_hidden_act: str = "quick_gelu", + # text + max_text_length: int = 77, + vocab_size: int = 49408, + text_embed_dim: int = 512, + text_heads: int = 8, + text_layers: int = 12, + text_hidden_act: str = "quick_gelu", + # others + projection_dim: int = 512, + initializer_range: float = 0.02, + initializer_factor: float = 1.0, + logit_scale_init_value: float = 2.6592): + super().__init__() + self.initializer_factor = initializer_factor + self.initializer_range = initializer_range + self.logit_scale_init_value = logit_scale_init_value + self.vision_embed_dim = vision_embed_dim + self.text_embed_dim = text_embed_dim + self.vision_layers = vision_layers + self.text_layers = text_layers + + if isinstance(vision_layers, (tuple, list)): + if vision_heads is None: + vision_heads = vision_embed_dim * 32 // 64 + self.vision_model = ModifiedResNet( + layers=vision_layers, + output_dim=projection_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_embed_dim) + else: + if vision_heads is None: + vision_heads = vision_embed_dim // 64 + self.vision_model = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_embed_dim, + layers=vision_layers, + heads=vision_heads, + activation=vision_hidden_act, + mlp_ratio=vision_mlp_ratio, + normalize_before=True) + self.vision_projection = paddle.create_parameter( + (vision_embed_dim, projection_dim), paddle.get_default_dtype()) + + self.text_model = TextTransformer(context_length=max_text_length, + transformer_width=text_embed_dim, + transformer_heads=text_heads, + transformer_layers=text_layers, + vocab_size=vocab_size, + activation=text_hidden_act, + normalize_before=True) + + self.text_projection 
= paddle.create_parameter( + (text_embed_dim, projection_dim), paddle.get_default_dtype()) + + self.logit_scale = paddle.create_parameter( + (1, ), + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(logit_scale_init_value)) + self.apply(self._init_weights) + + def get_image_features(self, + pixel_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + r""" + Returns: + image_features (`paddle.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`VisionTransformer`]. + + Examples: + .. code-block:: + + import requests + from PIL import Image + from paddlenlp.transformers import CLIPProcessor, CLIPModel + + model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + inputs = processor(images=image, return_tensors="pd") + image_features = model.get_image_features(**inputs) + + """ + if isinstance(self.vision_model, ModifiedResNet): + return self.vision_model(pixel_values) + else: + vision_outputs = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + pooled_output = vision_outputs[1] + image_features = paddle.matmul(pooled_output, + self.vision_projection) + return image_features + + def get_text_features( + self, + input_ids, + attention_mask=None, + position_ids=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + r""" + Returns: + text_features (`paddle.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`TextTransformer`]. + + Example: + .. code-block:: + + from paddlenlp.transformers import CLIPModel, CLIPTokenizer + + model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pd") + text_features = model.get_text_features(**inputs) + + """ + text_outputs = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + pooled_output = text_outputs[1] + text_features = paddle.matmul(pooled_output, self.text_projection) + return text_features + + def forward(self, + input_ids, + pixel_values, + attention_mask=None, + position_ids=None, + return_loss=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + r''' + The CLIPModel forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. + Its data type should be `int64` and it has a shape of [text_batch_size, sequence_length]. + pixel_values (Tensor): + Pixel values. Padding will be ignored by default should you provide it. + Its data type should be `float32` and it has a shape of [image_batch_size, num_channels, height, width]. + position_ids(Tensor, optional): + Indices of positions of each input sequence tokens in the position embeddings (TextTransformer). 
Selected in + the range ``[0, max_text_length - 1]``. + Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`. + attention_mask (Tensor, optional): + Mask used in multi-head attention (TextTransformer) to avoid performing attention on to some unwanted positions, + usually the paddings or the subsequent positions. + Its data type can be int, float and bool. + When the data type is bool, the `masked` tokens have `False` values and the others have `True` values. + When the data type is int, the `masked` tokens have `0` values and the others have `1` values. + When the data type is float, the `masked` tokens have `0.0` values and the others have `1.0` values. + It is a tensor with shape `[batch_size, sequence_length`. + Defaults to `None`, which means nothing needed to be prevented attention to. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`CLIPOutput` object. If `False`, the output + will be a tuple of tensors. Defaults to `False`. + + Returns: + An instance of :class:`CLIPOutput` if `return_dict=True`. Otherwise it returns a tuple of tensors + corresponding to ordered and not None (depending on the input arguments) fields of :class:`CLIPOutput`. + + Example: + .. code-block:: + + import requests + import paddle.nn.functional as F + from PIL import Image + from paddlenlp.transformers import CLIPModel, CLIPProcessor + + processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32') + model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32') + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + inputs = processor(text=["a photo of a cat", "a photo of a dog"], + images=image, + padding=True, + return_tensors="pd") + + outputs = model(**inputs) + + logits_per_image = outputs[0] + probs = F.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities + + ''' + if isinstance(self.vision_model, ModifiedResNet): + vision_outputs = None + image_embeds = self.vision_model(pixel_values) + else: + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeds = paddle.matmul(vision_outputs[1], + self.vision_projection) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + text_embeds = paddle.matmul(text_outputs[1], self.text_projection) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(axis=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(axis=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = paddle.matmul( + text_embeds, image_embeds, transpose_y=True) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, + image_embeds, text_outputs, vision_outputs) + return ((loss, ) + output) if loss is not None else output + + return CLIPOutput( + 
loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class CLIPTextPretrainedModel(CLIPPretrainedModel): + pass + + +@register_base_model +class CLIPTextModel(CLIPTextPretrainedModel): + r""" + The bare CLIPTextModel outputting :class:`BaseModelOutputWithPoolingAndCrossAttentions`. + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + max_text_length (int, optional): + The maximum value of the dimensionality of text position encoding, which dictates the maximum supported length of the text + input sequence. Defaults to `64`. + vocab_size (int, optional): + Vocabulary size of `inputs_ids` in `CLIPModel`. Also is the vocab size of text token embedding matrix. + Defaults to `49408`. + text_embed_dim (int, optional): + Dimensionality of the embedding layer and encoder layers in text model. + Defaults to `768`. + text_heads (int, optional): + Number of attention heads for each attention layer in the text attention. + Defaults to `8`. + text_layers (int, optional): + Number of hidden layers in the text model. + Defaults to `12`. + text_hidden_act (str, optional): + The non-linear activation function of the ffn layer in the text model. + ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported. + Defaults to `"quick_gelu"`. + initializer_range (float, optional): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Default to `0.02`. + initializer_factor (float, optional): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). Default to `1.`. + + """ + + def __init__(self, + max_text_length=77, + text_embed_dim=512, + text_heads=8, + text_layers=12, + vocab_size=49408, + text_hidden_act="quick_gelu", + initializer_range=0.02, + initializer_factor=1.0, + **kwargs): + super().__init__() + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.text_embed_dim = text_embed_dim + self.text_layers = text_layers + self.text_model = TextTransformer(context_length=max_text_length, + transformer_width=text_embed_dim, + transformer_heads=text_heads, + transformer_layers=text_layers, + vocab_size=vocab_size, + activation=text_hidden_act, + normalize_before=True) + self.apply(self._init_weights) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + r""" + The CLIPTextModel forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. + Its data type should be `int64` and it has a shape of [text_batch_size, sequence_length]. + attention_mask (Tensor, optional): + Mask used in multi-head attention (TextTransformer) to avoid performing attention on to some unwanted positions, + usually the paddings or the subsequent positions. + Its data type can be int, float and bool. 
+ When the data type is bool, the `masked` tokens have `False` values and the others have `True` values. + When the data type is int, the `masked` tokens have `0` values and the others have `1` values. + When the data type is float, the `masked` tokens have `0.0` values and the others have `1.0` values. + It is a tensor with shape `[batch_size, sequence_length`. + Defaults to `None`, which means nothing needed to be prevented attention to. + position_ids(Tensor, optional): + Indices of positions of each input sequence tokens in the position embeddings (TextTransformer). Selected in + the range ``[0, max_text_length - 1]``. + Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`BaseModelOutputWithPoolingAndCrossAttentions` object. If `False`, the output + will be a tuple of tensors. Defaults to `False`. + + Returns: + An instance of :class:`BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors + corresponding to ordered and not None (depending on the input arguments) fields of :class:`BaseModelOutputWithPoolingAndCrossAttentions`. + + Example: + .. code-block:: + + from paddlenlp.transformers import CLIPTokenizer, CLIPTextModel + + model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pd") + outputs = model(**inputs) + last_hidden_state = outputs.last_hidden_state + pooled_output = outputs.pooler_output + + """ + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionPretrainedModel(CLIPPretrainedModel): + pass + + +@register_base_model +class CLIPVisionModel(CLIPVisionPretrainedModel): + r""" + The bare CLIPVisionModel outputting :class:`BaseModelOutputWithPoolingAndCrossAttentions`. + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + image_resolution (int, optional): + The size (resolution) of each image. + Defaults to `224`. + vision_layers (int, optional): + Number of hidden layers in the vision model. + Defaults to `12`. + vision_heads (int, optional): + Number of attention heads for each attention layer in the vision attention. + Defaults to `12`. + vision_embed_dim (int, optional): + Dimensionality of the embedding layer and encoder layers in vision model. + Defaults to `768`. + vision_patch_size(int, optional): + The size (resolution) of each patch. + Defaults to `32`. + vision_mlp_ratio(int, optional): + The ratio between dim_feedforward and vision_hidden_dim. `radio = dim_feedforward/vision_hidden_dim` + Defaults to `4`. + vision_hidden_act (str, optional): + The non-linear activation function of the ffn layer in the vision model. 
+ ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported. + Defaults to `"quick_gelu"`. + initializer_range (float, optional): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Default to `0.02`. + initializer_factor (float, optional): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). Default to `1.`. + + """ + + def __init__(self, + image_resolution=224, + vision_patch_size=32, + vision_embed_dim=768, + vision_layers=12, + vision_heads=12, + vision_hidden_act="quick_gelu", + vision_mlp_ratio=4, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs): + super().__init__() + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.vision_embed_dim = vision_embed_dim + self.vision_layers = vision_layers + + if vision_heads is None: + vision_heads = vision_embed_dim // 64 + self.vision_model = VisionTransformer(input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_embed_dim, + layers=vision_layers, + heads=vision_heads, + activation=vision_hidden_act, + mlp_ratio=vision_mlp_ratio, + normalize_before=True) + + self.apply(self._init_weights) + + def forward( + self, + pixel_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + r''' + The CLIPVisionModel forward method, overrides the `__call__()` special method. + + Args: + pixel_values (Tensor): + Pixel values. Padding will be ignored by default should you provide it. + Its data type should be `float32` and it has a shape of [image_batch_size, num_channels, height, width]. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`BaseModelOutputWithPoolingAndCrossAttentions` object. If `False`, the output + will be a tuple of tensors. Defaults to `False`. + + Returns: + An instance of :class:`BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors + corresponding to ordered and not None (depending on the input arguments) fields of :class:`BaseModelOutputWithPoolingAndCrossAttentions`. + + Example: + .. code-block:: + + from PIL import Image + import requests + from paddlenlp.transformers import CLIPProcessor, CLIPVisionModel + + model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + inputs = processor(images=image, return_tensors="pd") + outputs = model(**inputs) + + last_hidden_state = outputs.last_hidden_state + pooled_output = outputs.pooler_output # pooled CLS states + + ''' + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPForImageGeneration(CLIPPretrainedModel, DiscoDiffusionMixin, + StableDiffusionMixin): + r""" + CLIP Model with diffusion model on top. + + Args: + clip (:class:`CLIPModel`): + An instance of CLIPModel. + diffusion_type (str, optional): + The type of diffusion. 
Please choose in ['disco', 'stable']. + Defaults to `disco`. + scheduler_type (str, optional): + The type of scheduler. Please choose in ['pndm', 'ddim', 'k-lms']. + Defaults to `pndm`. + """ + + def __init__(self, clip, diffusion_type="disco", scheduler_type="pndm"): + super().__init__() + self.clip = clip + self.diffusion_type = diffusion_type + + if diffusion_type == "disco": + self.unet_model = create_unet_model( + image_size=512, + num_channels=256, + num_res_blocks=2, + channel_mult="", + learn_sigma=True, + class_cond=False, + attention_resolutions='32, 16, 8', + num_heads=4, + num_head_channels=64, + num_heads_upsample=-1, + use_scale_shift_norm=True, + dropout=0.0, + resblock_updown=True, + use_new_attention_order=False, + ) + self.secondary_model = create_secondary_model() + + elif diffusion_type == "stable": + del self.clip.vision_model + self.vae_model = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types=( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels=(128, 256, 512, 512), + layers_per_block=2, + act_fn="silu", + latent_channels=4, + sample_size=512, + ) + if scheduler_type == "pndm": + self.scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + skip_prk_steps=True, + ) + elif scheduler_type == "ddim": + self.scheduler = DDIMScheduler(beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False) + elif scheduler_type == "k-lms": + self.scheduler = LMSDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear") + else: + raise ValueError( + 'scheduler_type must be in ["pndm", "ddim", "k-lms"]') + + self.unet_model = UNet2DConditionModel( + sample_size=64, + in_channels=4, + out_channels=4, + center_input_sample=False, + flip_sin_to_cos=True, + freq_shift=0, + down_block_types=( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + block_out_channels=(320, 640, 1280, 1280), + layers_per_block=2, + downsample_padding=1, + mid_block_scale_factor=1, + act_fn="silu", + norm_num_groups=32, + norm_eps=1e-5, + cross_attention_dim=768, + attention_head_dim=8, + ) + # input_ids_uncond + # [49406, 49407, 49407, 49407, 49407,...,49407] + input_ids_uncond = [ + 49406, 49407 + ] + [49407] * (clip.config["max_text_length"] - 2) + self.register_buffer("input_ids_uncond", + paddle.to_tensor([input_ids_uncond], + dtype="int64"), + persistable=False) + else: + raise ValueError( + "diffusion_type: Please choose in ['disco', 'stable']") + + # eval mode and stop all param's gradient + self.eval() + for param in self.parameters(): + param.stop_gradient = True + + def generate(self, *args, **kwargs): + if self.diffusion_type == "disco": + return self.disco_diffusion_generate(*args, **kwargs) + else: + return self.stable_diffusion_generate(*args, **kwargs) + + def stable_diffusion_generate(self, + input_ids, + mode="text2image", + init_image=None, + mask_image=None, + seed=None, + strength=0.8, + height=512, + width=512, + num_inference_steps=50, + guidance_scale=7.5, + eta=0.0, + latents=None, + fp16=True, + **kwargs): + r""" + The CLIPForImageGeneration stable_diffusion_generate method. 
+
+        Args:
+            input_ids (Tensor):
+                See :class:`CLIPModel`.
+            mode (str, optional):
+                The mode of StableDiffusion. Supported modes are ["text2image", "image2image", "inpaint"].
+                Default to `"text2image"`.
+            init_image (Union[str, PIL.Image.Image], optional):
+                In `"image2image"` or `"inpaint"` mode, the `init_image` must be provided.
+                Used in `"image2image"` and `"inpaint"` mode.
+                Default to `None`.
+            mask_image (Union[str, PIL.Image.Image], optional):
+                In `"inpaint"` mode, the `mask_image` must be provided.
+                Used in `"inpaint"` mode.
+                Default to `None`.
+            seed (int, optional):
+                A random number seed which is used as the basis for determining the initial noise.
+                Default to `None`.
+            strength (float, optional):
+                A value between 0.0 and 1.0 that controls the amount of noise added to the input image.
+                Values that approach 1.0 allow for lots of variations but will also produce images that
+                are not semantically consistent with the input.
+                Used in `"image2image"` and `"inpaint"` mode.
+                Default to `0.8`.
+            height (int, optional):
+                The height of the image you want to generate, in pixels.
+                `height` has to be divisible by 64. Used in `"text2image"` mode.
+                Default to `512`.
+            width (int, optional):
+                The width of the image you want to generate, in pixels.
+                `width` has to be divisible by 64. Used in `"text2image"` mode.
+                Default to `512`.
+            num_inference_steps (int, optional):
+                Indicates the number of inference steps. Generally speaking, the more steps, the better
+                the result, but more steps also mean longer generation time. Stable Diffusion gives good
+                results with relatively few steps, so the default value is set to 50. Reduce this value
+                for faster generation, or increase it for better results.
+                Default to `50`.
+            guidance_scale (float, optional):
+                `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+                of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+                corresponds to doing no classifier-free guidance.
+                Default to `7.5`.
+            eta (float, optional):
+                eta (η) is only used with the DDIMScheduler and will be ignored for other schedulers.
+                eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502.
+                Default to `0.0`.
+            latents (paddle.Tensor, optional):
+                Optional pre-generated latents. Their shape should be
+                `[batch_size, unet_model.in_channels, height // 8, width // 8]`.
+                Default to `None`.
+            fp16 (bool, optional):
+                Whether to use fp16 for inference.
+                Default to `True`.
+ + """ + assert input_ids is not None, "text2image/image2image/inpaint, please specify `input_ids`" + if mode == "text2image": + return self.stable_diffusion_text2image( + input_ids=input_ids, + seed=seed, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + eta=eta, + latents=latents, + fp16=fp16) + elif mode == "image2image": + assert init_image is not None, "image2image mode, please specify `init_image`" + # preprocess image + init_image = self.preprocess_image(init_image) + return self.stable_diffusion_image2image( + input_ids=input_ids, + init_image=init_image, + strength=strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + eta=eta, + seed=seed, + fp16=fp16) + + elif mode == "inpaint": + assert init_image is not None, "inpaint mode, please specify `init_image`" + assert mask_image is not None, "inpaint mode, please specify `mask_image`" + # preprocess image + init_image = self.preprocess_image(init_image) + # preprocess mask + mask_image = self.preprocess_mask(mask_image) + + return self.stable_diffusion_inpainting( + input_ids=input_ids, + init_image=init_image, + mask_image=mask_image, + strength=strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + eta=eta, + seed=seed, + fp16=fp16) + else: + raise ValueError( + 'Mode must be in ["text2image", "image2image", "inpaint"]') + + def disco_diffusion_generate( + self, + input_ids, + attention_mask=None, + position_ids=None, + init_image=None, + output_dir='disco_diffusion_clip_vitb32_out/', + width_height=[1280, 768], + skip_steps=0, + steps=250, + cut_ic_pow=1, + init_scale=1000, + clip_guidance_scale=5000, + tv_scale=0, + range_scale=0, + sat_scale=0, + cutn_batches=4, + perlin_init=False, + perlin_mode='mixed', + seed=None, + eta=0.8, + clamp_grad=True, + clamp_max=0.05, + cut_overview='[12]*400+[4]*600', + cut_innercut='[4]*400+[12]*600', + cut_icgray_p='[0.2]*400+[0]*600', + save_rate=10, + n_batches=1, + batch_name="", + use_secondary_model=True, + randomize_class=True, + clip_denoised=False, + ): + r""" + The CLIPForImageGeneration disco_diffusion_generate method. + + Args: + input_ids (Tensor): + See :class:`CLIPModel`. + attention_mask (Tensor, optional): + See :class:`CLIPModel`. + position_ids (Tensor, optional): + See :class:`CLIPModel`. + init_image (Path, optional): + Recall that in the image sequence above, the first image shown is just noise. If an init_image + is provided, diffusion will replace the noise with the init_image as its starting state. To use + an init_image, upload the image to the Colab instance or your Google Drive, and enter the full + image path here. If using an init_image, you may need to increase skip_steps to ~ 50% of total + steps to retain the character of the init. See skip_steps above for further discussion. + Default to `None`. + output_dir (Path, optional): + Output directory. + Default to `disco_diffusion_clip_vitb32_out`. + width_height (List[int, int], optional): + Desired final image size, in pixels. You can have a square, wide, or tall image, but each edge + length should be set to a multiple of 64px, and a minimum of 512px on the default CLIP model setting. + If you forget to use multiples of 64px in your dimensions, DD will adjust the dimensions of your + image to make it so. + Default to `[1280, 768]`. + skip_steps (int, optional): + Consider the chart shown here. 
Noise scheduling (denoise strength) starts very high and progressively + gets lower and lower as diffusion steps progress. The noise levels in the first few steps are very high, + so images change dramatically in early steps.As DD moves along the curve, noise levels (and thus the + amount an image changes per step) declines, and image coherence from one step to the next increases. + The first few steps of denoising are often so dramatic that some steps (maybe 10-15% of total) can be + skipped without affecting the final image. You can experiment with this as a way to cut render times. + If you skip too many steps, however, the remaining noise may not be high enough to generate new content, + and thus may not have time left to finish an image satisfactorily.Also, depending on your other settings, + you may need to skip steps to prevent CLIP from overshooting your goal, resulting in blown out colors + (hyper saturated, solid white, or solid black regions) or otherwise poor image quality. Consider that + the denoising process is at its strongest in the early steps, so skipping steps can sometimes mitigate + other problems.Lastly, if using an init_image, you will need to skip ~50% of the diffusion steps to retain + the shapes in the original init image. However, if you're using an init_image, you can also adjust + skip_steps up or down for creative reasons. With low skip_steps you can get a result "inspired by" + the init_image which will retain the colors and rough layout and shapes but look quite different. + With high skip_steps you can preserve most of the init_image contents and just do fine tuning of the texture. + Default to `10`. + steps: + When creating an image, the denoising curve is subdivided into steps for processing. Each step (or iteration) + involves the AI looking at subsets of the image called 'cuts' and calculating the 'direction' the image + should be guided to be more like the prompt. Then it adjusts the image with the help of the diffusion denoiser, + and moves to the next step.Increasing steps will provide more opportunities for the AI to adjust the image, + and each adjustment will be smaller, and thus will yield a more precise, detailed image. Increasing steps + comes at the expense of longer render times. Also, while increasing steps should generally increase image + quality, there is a diminishing return on additional steps beyond 250 - 500 steps. However, some intricate + images can take 1000, 2000, or more steps. It is really up to the user. Just know that the render time is + directly related to the number of steps, and many other parameters have a major impact on image quality, without + costing additional time. + cut_ic_pow (int, optional): + This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and + therefore the cuts themselves will be smaller and provide finer details. If you have too many or too-small + inner cuts, you may lose overall image coherency and/or it may cause an undesirable 'mosaic' effect. + Low cut_ic_pow values will allow the inner cuts to be larger, helping image coherency while still helping + with some details. + Default to `1`. + init_scale (int, optional): + This controls how strongly CLIP will try to match the init_image provided. This is balanced against the + clip_guidance_scale (CGS) above. Too much init scale, and the image won't change much during diffusion. + Too much CGS and the init image will be lost. + Default to `1000`. 
+ clip_guidance_scale (int, optional): + CGS is one of the most important parameters you will use. It tells DD how strongly you want CLIP to move + toward your prompt each timestep. Higher is generally better, but if CGS is too strong it will overshoot + the goal and distort the image. So a happy medium is needed, and it takes experience to learn how to adjust + CGS. Note that this parameter generally scales with image dimensions. In other words, if you increase your + total dimensions by 50% (e.g. a change from 512 x 512 to 512 x 768), then to maintain the same effect on the + image, you'd want to increase clip_guidance_scale from 5000 to 7500. Of the basic settings, clip_guidance_scale, + steps and skip_steps are the most important contributors to image quality, so learn them well. + Default to `5000`. + tv_scale (int, optional): + Total variance denoising. Optional, set to zero to turn off. Controls smoothness of final output. If used, + tv_scale will try to smooth out your final image to reduce overall noise. If your image is too 'crunchy', + increase tv_scale. TV denoising is good at preserving edges while smoothing away noise in flat regions. + See https://en.wikipedia.org/wiki/Total_variation_denoising + Default to `0`. + range_scale (int, optional): + Optional, set to zero to turn off. Used for adjustment of color contrast. Lower range_scale will increase + contrast. Very low numbers create a reduced color palette, resulting in more vibrant or poster-like images. + Higher range_scale will reduce contrast, for more muted images. + Default to `0`. + sat_scale (int, optional): + Saturation scale. Optional, set to zero to turn off. If used, sat_scale will help mitigate oversaturation. + If your image is too saturated, increase sat_scale to reduce the saturation. + Default to `0`. + cutn_batches (int, optional): + Each iteration, the AI cuts the image into smaller pieces known as cuts, and compares each cut to the prompt + to decide how to guide the next diffusion step. More cuts can generally lead to better images, since DD has + more chances to fine-tune the image precision in each timestep. Additional cuts are memory intensive, however, + and if DD tries to evaluate too many cuts at once, it can run out of memory. You can use cutn_batches to increase + cuts per timestep without increasing memory usage. At the default settings, DD is scheduled to do 16 cuts per + timestep. If cutn_batches is set to 1, there will indeed only be 16 cuts total per timestep. However, if + cutn_batches is increased to 4, DD will do 64 cuts total in each timestep, divided into 4 sequential batches + of 16 cuts each. Because the cuts are being evaluated only 16 at a time, DD uses the memory required for only 16 cuts, + but gives you the quality benefit of 64 cuts. The tradeoff, of course, is that this will take ~4 times as long to + render each image.So, (scheduled cuts) x (cutn_batches) = (total cuts per timestep). Increasing cutn_batches will + increase render times, however, as the work is being done sequentially. DD's default cut schedule is a good place + to start, but the cut schedule can be adjusted in the Cutn Scheduling section, explained below. + Default to `4`. + perlin_init (bool, optional): + Normally, DD will use an image filled with random noise as a starting point for the diffusion curve. + If perlin_init is selected, DD will instead use a Perlin noise model as an initial state. 
Perlin has very + interesting characteristics, distinct from random noise, so it's worth experimenting with this for your projects. + Beyond perlin, you can, of course, generate your own noise images (such as with GIMP, etc) and use them as an + init_image (without skipping steps). Choosing perlin_init does not affect the actual diffusion process, just the + starting point for the diffusion. Please note that selecting a perlin_init will replace and override any init_image + you may have specified. Further, because the 2D, 3D and video animation systems all rely on the init_image system, + if you enable Perlin while using animation modes, the perlin_init will jump in front of any previous image or video + input, and DD will NOT give you the expected sequence of coherent images. All of that said, using Perlin and + animation modes together do make a very colorful rainbow effect, which can be used creatively. + Default to `False`. + perlin_mode (str, optional): + sets type of Perlin noise: colored, gray, or a mix of both, giving you additional options for noise types. Experiment + to see what these do in your projects. + Default to `mixed`. + seed (int, optional): + Deep in the diffusion code, there is a random number seed which is used as the basis for determining the initial + state of the diffusion. By default, this is random, but you can also specify your own seed. This is useful if you like a + particular result and would like to run more iterations that will be similar. After each run, the actual seed value used will be + reported in the parameters report, and can be reused if desired by entering seed # here. If a specific numerical seed is used + repeatedly, the resulting images will be quite similar but not identical. + Default to `None`. + eta (float, optional): + Eta (greek letter η) is a diffusion model variable that mixes in a random amount of scaled noise into each timestep. + 0 is no noise, 1.0 is more noise. As with most DD parameters, you can go below zero for eta, but it may give you + unpredictable results. The steps parameter has a close relationship with the eta parameter. If you set eta to 0, + then you can get decent output with only 50-75 steps. Setting eta to 1.0 favors higher step counts, ideally around + 250 and up. eta has a subtle, unpredictable effect on image, so you'll need to experiment to see how this affects your projects. + Default to `0.8`. + clamp_grad (bool, optional): + As I understand it, clamp_grad is an internal limiter that stops DD from producing extreme results. Try your images with and without + clamp_grad. If the image changes drastically with clamp_grad turned off, it probably means your clip_guidance_scale is too high and + should be reduced. + Default to `True`. + clamp_max (float, optional): + Sets the value of the clamp_grad limitation. Default is 0.05, providing for smoother, more muted coloration in images, but setting + higher values (0.15-0.3) can provide interesting contrast and vibrancy. + Default to `0.05`. + cut_overview (str, optional): + The schedule of overview cuts. + Default to `'[12]*400+[4]*600'`. + cut_innercut (str, optional): + The schedule of inner cuts. + Default to `'[4]*400+[12]*600'`. + cut_icgray_p (str, optional): + This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and therefore the cuts + themselves will be smaller and provide finer details. 
If you have too many or too-small inner cuts, you may lose overall + image coherency and/or it may cause an undesirable 'mosaic' effect. Low cut_ic_pow values will allow the inner cuts to be + larger, helping image coherency while still helping with some details. + Default to `'[0.2]*400+[0]*600'`. + save_rate (int, optional): + During a diffusion run, you can monitor the progress of each image being created with this variable. If display_rate is set + to 50, DD will show you the in-progress image every 50 timesteps. Setting this to a lower value, like 5 or 10, is a good way + to get an early peek at where your image is heading. If you don't like the progression, just interrupt execution, change some + settings, and re-run. If you are planning a long, unmonitored batch, it's better to set display_rate equal to steps, because + displaying interim images does slow Colab down slightly. + Default to `10`. + n_batches (int, optional): + This variable sets the number of still images you want DD to create. If you are using an animation mode (see below for details) + DD will ignore n_batches and create a single set of animated frames based on the animation settings. + Default to `1`. + batch_name (str, optional): + The name of the batch, the batch id will be named as "progress-[batch_name]-seed-[range(n_batches)]-[save_rate]". To avoid your + artworks be overridden by other users, please use a unique name. + Default to `''`. + use_secondary_model (bool, optional): + Whether or not use secondary model. + Default to `True`. + randomize_class (bool, optional): + Random class. + Default to `True`. + clip_denoised (bool, optional): + Clip denoised. + Default to `False`. + + Returns: + List[PIL.Image]: Returns n_batches of final image. + Its data type should be PIL.Image. + + Example: + .. code-block:: + + from paddlenlp.transformers import CLIPForImageGeneration, CLIPTokenizer + + # Initialize the model and tokenizer + model_name_or_path = 'openai/disco-diffusion-clip-vit-base-patch32' + model = CLIPForImageGeneration.from_pretrained(model_name_or_path) + tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path) + model.eval() + + # Prepare the text_prompt. + text_prompt = "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation." + # Style and artist can be specified + style = None + artist = None + text_prompt = model.preprocess_text_prompt(text_prompt, style=style, artist=artist) + + # CLIP's pad_token_id is 0 and padding to max length 77 (tokenizer.model_max_length). 
+ raw_pad_token_id = tokenizer.pad_token_id + tokenizer.pad_token_id = 0 + tokenized_inputs = tokenizer(text_prompt, return_tensors="pd", padding="max_length", max_length=tokenizer.model_max_length) + images = model.generate(**tokenized_inputs) + + # return List[PIL.Image] + images[0].save("figure.png") + + """ + self.diffusion = create_gaussian_diffusion( + steps=steps, + learn_sigma=True, + sigma_small=False, + noise_schedule="linear", + predict_xstart=False, + rescale_timesteps=True, + ) + # get get_text_features + target_text_embeds = self.get_text_features( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids) + + images_list = super().disco_diffusion_generate( + target_text_embeds=target_text_embeds, + clamp_grad=clamp_grad, + clamp_max=clamp_max, + clip_denoised=clip_denoised, + clip_guidance_scale=clip_guidance_scale, + cut_ic_pow=cut_ic_pow, + cut_icgray_p=cut_icgray_p, + cut_innercut=cut_innercut, + cut_overview=cut_overview, + cutn_batches=cutn_batches, + save_rate=save_rate, + eta=eta, + init_image=init_image, + init_scale=init_scale, + n_batches=n_batches, + output_dir=output_dir, + perlin_init=perlin_init, + perlin_mode=perlin_mode, + randomize_class=randomize_class, + range_scale=range_scale, + sat_scale=sat_scale, + seed=seed, + skip_steps=skip_steps, + tv_scale=tv_scale, + use_secondary_model=use_secondary_model, + width_height=width_height, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + batch_name=batch_name) + + return images_list + + def preprocess_text_prompt(self, text_prompt, style=None, artist=None): + text_prompt = text_prompt.rstrip(',.,。') + if style is not None: + text_prompt += ",{}".format(style) + if artist is not None: + text_prompt += ",{},trending on artstation".format(artist) + return text_prompt + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError as e: + try: + return getattr(getattr(self, self.base_model_prefix), name) + except AttributeError: + try: + return getattr(self, self.base_model_prefix).config[name] + except KeyError: + raise e diff --git a/paddlenlp/transformers/clip/procesing.py b/paddlenlp/transformers/clip/procesing.py new file mode 100644 index 000000000000..2679b642b801 --- /dev/null +++ b/paddlenlp/transformers/clip/procesing.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for CLIP +""" + +from ..tokenizer_utils_base import BatchEncoding +from .tokenizer import CLIPTokenizer +from .feature_extraction import CLIPFeatureExtractor + +__all__ = ["CLIPProcessor"] + + +class CLIPProcessor(object): + r""" + Constructs a CLIP processor which wraps a CLIP feature extractor and a CLIP tokenizer into a single processor. + [`CLIPProcessor`] offers all the functionalities of [`CLIPFeatureExtractor`] and [`CLIPTokenizer`]. 
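    A minimal usage sketch (the image below is just a blank placeholder; any PIL image, NumPy array or
    Paddle tensor works):

    .. code-block::

        from PIL import Image
        from paddlenlp.transformers import CLIPProcessor

        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        image = Image.new("RGB", (224, 224))  # placeholder image
        inputs = processor(text=["a photo of a cat"],
                           images=image,
                           return_tensors="pd")
        # expected keys: "input_ids" (from the tokenizer) and "pixel_values" (from the feature extractor)
        print(list(inputs.keys()))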
See the + [`CLIPProcessor.__call__`] and [`CLIPProcessor.decode`] for more information. + Args: + feature_extractor ([`CLIPFeatureExtractor`]): + The feature extractor is a required input. + tokenizer ([`CLIPTokenizer`]): + The tokenizer is a required input. + """ + + def __init__(self, feature_extractor, tokenizer): + super().__init__() + self.tokenizer = tokenizer + self.feature_extractor = feature_extractor + + def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to CLIPTokenizer's [`CLIPTokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPFeatureExtractor's [`CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the + doctsring of the above two methods for more information. + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle + tensor. In case of a NumPy array/Paddle tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'np'`: Return NumPy `np.ndarray` objects. + - `'pd'`: Return Paddle `paddle.Tensor` objects. + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError( + "You have to specify either text or images. Both cannot be none." + ) + + if text is not None: + encoding = self.tokenizer(text, + return_tensors=return_tensors, + **kwargs) + + if images is not None: + image_features = self.feature_extractor( + images, return_tensors=return_tensors, **kwargs) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchEncoding(data=dict(**image_features), + tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizer's [`~PreTrainedTokenizer.decode`]. 
Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + # TODO junnyu find a better way from_pretrained and save_pretrained + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, + *args, **kwargs) + feature_extractor = CLIPFeatureExtractor() + return cls(feature_extractor, tokenizer) + + def save_pretrained(self, save_directory, filename_prefix=None, **kwargs): + return self.tokenizer.save_pretrained(save_directory, filename_prefix, + **kwargs) diff --git a/paddlenlp/transformers/clip/tokenizer.py b/paddlenlp/transformers/clip/tokenizer.py new file mode 100644 index 000000000000..4e5d384dc9d3 --- /dev/null +++ b/paddlenlp/transformers/clip/tokenizer.py @@ -0,0 +1,419 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.utils import try_import +import os +import shutil +from .. import PretrainedTokenizer, AddedToken, BasicTokenizer +from ...utils.log import logger +from functools import lru_cache +import json + +__all__ = ['CLIPTokenizer'] + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = (list(range(ord("!"), + ord("~") + 1)) + list(range(ord("¡"), + ord("¬") + 1)) + + list(range(ord("®"), + ord("ÿ") + 1))) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def whitespace_clean(text, re): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class CLIPTokenizer(PretrainedTokenizer): + r""" + Construct a CLIP tokenizer based on byte-level Byte-Pair-Encoding. + + This tokenizer inherits from :class:`~paddlenlp.transformers.gpt.tokenizer.GPTTokenizer`. + For more information regarding those methods, please refer to this superclass. + + Args: + vocab_file (str): + Path to the vocabulary file. + The vocab file contains a mapping from vocabulary strings to indices. 
+ merges_file (str): + Path to the merge file. + The merge file is used to split the input sentence into "subword" units. + The vocab file is then used to encode those units as intices. + errors (str): + Paradigm to follow when decoding bytes to UTF-8. + Defaults to `'replace'`. + max_len (int, optional): + The maximum value of the input sequence length. + Defaults to `77`. + bos_token (str, optional): + The beginning of sequence token that was used during pretraining. Can be + used a sequence classifier token. + Defaults to `"<|startoftext|>"`. + eos_token (str, optional): + A special token representing the end of a sequence that was used during pretraining. + Defaults to `"<|endoftext|>"`. + unk_token (str, optional): + A special token representing the *unknown (out-of-vocabulary)* token. + An unknown token is set to be `unk_token` inorder to be converted to an ID. + Defaults to `"<|endoftext|>"`. + pad_token (str, optional): + A special token used to make arrays of tokens the same size for batching purposes. + Defaults to `"<|endoftext|>"`. + + Examples: + .. code-block:: + + from paddlenlp.transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32') + print(tokenizer('He was a puppeteer')) + + ''' + {'input_ids': [49406, 797, 739, 320, 7116, 38820, 528, 49407]} + ''' + + """ + # merges and vocab same as GPT2 + resource_files_names = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt" + } + pretrained_resource_files_map = { + "vocab_file": { + "openai/clip-vit-base-patch32": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-base-patch32/vocab.json", + "openai/clip-rn50": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn50/vocab.json", + "openai/clip-rn101": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn101/vocab.json", + "openai/clip-vit-large-patch14": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-large-patch14/vocab.json", + }, + "merges_file": { + "openai/clip-vit-base-patch32": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-base-patch32/merges.txt", + "openai/clip-rn50": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn50/merges.txt", + "openai/clip-rn101": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn101/merges.txt", + "openai/clip-vit-large-patch14": + "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-large-patch14/merges.txt", + } + } + pretrained_init_configuration = { + "openai/clip-vit-base-patch32": { + "max_len": 77 + }, + "openai/clip-rn50": { + "max_len": 77 + }, + "openai/clip-rn101": { + "max_len": 77 + }, + "openai/clip-vit-large-patch14": { + "max_len": 77 + }, + } + + def __init__(self, + vocab_file, + merges_file, + errors='replace', + max_len=77, + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + unk_token="<|endoftext|>", + pad_token="<|endoftext|>", + **kwargs): + + bos_token = AddedToken(bos_token, + lstrip=False, rstrip=False) if isinstance( + bos_token, str) else bos_token + eos_token = AddedToken(eos_token, + lstrip=False, rstrip=False) if isinstance( + eos_token, str) else eos_token + unk_token = AddedToken(unk_token, + lstrip=False, rstrip=False) if isinstance( + unk_token, str) else unk_token + pad_token = AddedToken(pad_token, + lstrip=False, rstrip=False) if isinstance( + pad_token, str) else pad_token + + self._build_special_tokens_map_extended(bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token) + + try: 
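            # Optional dependency: use ftfy's text fixer when it is installed; otherwise fall back to the
            # BERT BasicTokenizer configured in the except branch (see `_tokenize`).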
+ import ftfy + self.fix_text = ftfy.fix_text + except ImportError: + logger.warning( + "ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy." + ) + self.nlp = BasicTokenizer(do_lower_case=True) + self.fix_text = None + self.re = try_import("regex") + + self._vocab_file = vocab_file + self._merges_file = merges_file + self.max_len = max_len if max_len is not None else int(1e12) + + with open(vocab_file, 'r', encoding='utf-8') as f: + self.encoder = json.load(f) + + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().strip().split("\n")[1:49152 - + 256 - 2 + 1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>" + } + + self.pat = self.re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + self.re.IGNORECASE, + ) + + @property + def vocab_size(self): + """ + Returns the size of vocabulary. + + Returns: + int: The sum of size of vocabulary and the size of speical tokens. + + """ + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. A CLIP sequence has the following format: + + - single sequence: `<|startoftext|> X <|endoftext|>` + + Pairs of sequences are not the expected use case, but they will be handled without a separator. + + Args: + token_ids_0 (List[int]): + List of IDs to which the special tokens will be added. + token_ids_1 (List[int], optional): + Optional second list of IDs for sequence pairs. Defaults to None. + + Returns: + List[int]: List of input_id with the appropriate special tokens. + """ + _bos = [self.bos_token_id] + _eos = [self.eos_token_id] + if token_ids_1 is None: + return _bos + token_ids_0 + _eos + return _bos + token_ids_0 + _eos + _eos + token_ids_1 + _eos + + def get_special_tokens_mask(self, + token_ids_0, + token_ids_1=None, + already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is + called when adding special tokens using the tokenizer ``encode`` methods. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ( + [0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. 
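        For CLIP both sequences are mapped to token type 0, so the returned list is simply zeros covering
        the special tokens and the sequence(s).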
+ """ + eos = [self.eos_token_id] + bos = [self.bos_token_id] + + if token_ids_1 is None: + return len(bos + token_ids_0 + eos) * [0] + return len(bos + token_ids_0 + eos + eos + token_ids_1 + eos) * [0] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "", ) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + if self.fix_text is None: + text = " ".join(self.nlp.tokenize(text)) + else: + text = whitespace_clean(self.fix_text(text), self.re).lower() + + for token in self.re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token + for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + return self.decoder[index] + + def convert_tokens_to_string(self, tokens): + """ + Converts a sequence of tokens (string) in a single string. + """ + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text + ]).decode('utf-8', + errors=self.errors).replace("", + " ").strip() + return text + + def save_resources(self, save_directory): + """ + Saves `SentencePiece `__ file + (ends with '.spm') under `save_directory`. + + Args: + save_directory (str): Directory to save files into. 
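        Note: CLIP does not use a SentencePiece model; this method copies the byte-level BPE resources
        listed in ``resource_files_names`` (``vocab.json`` and ``merges.txt``) into ``save_directory``.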
+ """ + for name, file_name in self.resource_files_names.items(): + source_path = getattr(self, "_%s" % name) + + save_path = os.path.join(save_directory, file_name) + if os.path.abspath(source_path) != os.path.abspath(save_path): + shutil.copyfile(source_path, save_path) + + def __call__( + self, + text, + text_pair=None, + max_length=None, + stride=0, + is_split_into_words=False, + padding=False, + truncation=False, + return_position_ids=False, + return_token_type_ids=False, # don't return token_type_ids + return_attention_mask=False, + return_length=False, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + return_dict=True, + return_offsets_mapping=False, + add_special_tokens=True, + pad_to_multiple_of=None, + return_tensors=None, + verbose: bool = True, + **kwargs): + return super().__call__( + text, text_pair, max_length, stride, is_split_into_words, padding, + truncation, return_position_ids, return_token_type_ids, + return_attention_mask, return_length, return_overflowing_tokens, + return_special_tokens_mask, return_dict, return_offsets_mapping, + add_special_tokens, pad_to_multiple_of, return_tensors, verbose, + **kwargs) diff --git a/paddlenlp/transformers/feature_extraction_utils.py b/paddlenlp/transformers/feature_extraction_utils.py new file mode 100644 index 000000000000..5e612ab3e19f --- /dev/null +++ b/paddlenlp/transformers/feature_extraction_utils.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from collections import UserDict +from typing import Any, Dict, Optional, Union + +import numpy as np +from .tokenizer_utils_base import TensorType + + +class BatchFeature(UserDict): + r""" + Holds the feature extractor specific `__call__` methods. + This class is derived from a python dictionary and can be used as a dictionary. + Args: + data (`dict`): + Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask', + etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in Paddle/Numpy Tensors at + initialization. + """ + + def __init__(self, + data: Optional[Dict[str, Any]] = None, + tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def __getitem__(self, item: str): + """ + If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask', + etc.). 
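        Integer indexing is deliberately unsupported and raises a ``KeyError``.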
+ """ + if isinstance(item, str): + return self.data[item] + else: + raise KeyError( + "Indexing with integers is not available when using Python based feature extractors" + ) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + def convert_to_tensors(self, + tensor_type: Optional[Union[str, + TensorType]] = None): + """ + Convert the inner content to tensors. + Args: + tensor_type (`str` or [`TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`TensorType`]. If + `None`, no modification is done. + """ + if tensor_type is None: + return self + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.PADDLE: + as_tensor = paddle.to_tensor + is_tensor = paddle.is_tensor + else: + as_tensor = np.asarray + is_tensor = lambda x: isinstance(x, np.ndarray) + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if not is_tensor(value): + tensor = as_tensor(value) + + self[key] = tensor + except: # noqa E722 + if key == "overflowing_tokens": + raise ValueError( + "Unable to create tensor returning overflowing tokens of different lengths. " + "Please see if a fast version of this tokenizer is available to have this feature available." + ) + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding " + "with 'padding=True' 'truncation=True' to have batched tensors with the same length." + ) + + return self diff --git a/paddlenlp/transformers/guided_diffusion_utils/__init__.py b/paddlenlp/transformers/guided_diffusion_utils/__init__.py new file mode 100644 index 000000000000..46c4d89f4cd9 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/__init__.py @@ -0,0 +1,2 @@ +from .model_diffusion import create_gaussian_diffusion, create_unet_model, create_secondary_model +from .utils import DiscoDiffusionMixin \ No newline at end of file diff --git a/paddlenlp/transformers/guided_diffusion_utils/gaussian_diffusion.py b/paddlenlp/transformers/guided_diffusion_utils/gaussian_diffusion.py new file mode 100755 index 000000000000..7736c42a6cef --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/gaussian_diffusion.py @@ -0,0 +1,771 @@ +""" +Diffusion model implemented by Paddle. +This code is rewritten based on Pytorch version of of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py +""" +import enum +import math + +import numpy as np +import paddle + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
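        # e.g. num_diffusion_timesteps = 1000 gives scale = 1, so betas increase linearly from 1e-4 to 0.02.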
+ scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, + beta_end, + num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class GaussianDiffusion: + """ + Utilities for sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. 
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps, ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - + 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * + np.sqrt(alphas) / + (1.0 - self.alphas_cumprod)) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
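        In closed form: x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise,
        which is exactly what the expression below computes.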
+ """ + if noise is None: + # noise = th.randn_like(x_start) + noise = paddle.randn(x_start.shape, x_start.dtype) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == + posterior_log_variance_clipped.shape[0] == x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == [B] + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE + ]: + assert model_output.shape == [B, C * 2, *x.shape[2:]] + model_output, model_var_values = paddle.split(model_output, + 2, + axis=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = paddle.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = paddle.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], + self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, + x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ + ModelMeanType.START_X, ModelMeanType.EPSILON + ]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert (model_mean.shape == model_log_variance.shape == + pred_xstart.shape == x.shape) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * + x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * eps) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) + * xprev - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + x_t.shape) * x_t) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * + x_t - pred_xstart) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return paddle.cast((t), 'float32') * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = (paddle.cast((p_mean_var["mean"]), 'float32') + + p_mean_var["variance"] * paddle.cast((gradient), 'float32')) + return new_mean + + def condition_mean_with_grad(self, + cond_fn, + p_mean_var, + x, + t, + model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
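        Concretely, the returned mean is mean + variance * grad_x log p(y | x_t), with the gradient
        supplied by ``cond_fn``.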
+ """ + gradient = cond_fn(x, t, p_mean_var, **model_kwargs) + new_mean = (paddle.cast((p_mean_var["mean"]), 'float32') + + p_mean_var["variance"] * paddle.cast((gradient), 'float32')) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def condition_score_with_grad(self, + cond_fn, + p_mean_var, + x, + t, + model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, p_mean_var, ** + model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + """ + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, + out_orig, + x, + t, + model_kwargs=model_kwargs) + else: + out = out_orig + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + sigma = (eta * paddle.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * + paddle.sqrt(1 - alpha_bar / alpha_bar_prev)) + # Equation 12. + # noise = th.randn_like(x) + noise = paddle.randn(x.shape, x.dtype) + mean_pred = (out["pred_xstart"] * paddle.sqrt(alpha_bar_prev) + + paddle.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = (paddle.cast( + (t != 0), 'float32').reshape([-1, *([1] * (len(x.shape) - 1))]) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} + + def ddim_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. 
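        As in ``ddim_sample`` above, the update follows Equation 12 of the DDIM paper:
        x_{t-1} = sqrt(alpha_bar_prev) * pred_xstart + sqrt(1 - alpha_bar_prev - sigma**2) * eps + sigma * noise,
        where sigma is scaled by ``eta`` and no noise is added at t == 0.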
+ + """ + # with th.enable_grad(): + # x = x.detach().requires_grad_() + x = x.detach() + # x.stop_gradient = False + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score_with_grad(cond_fn, + out_orig, + x, + t, + model_kwargs=model_kwargs) + else: + out = out_orig + + out["pred_xstart"] = out["pred_xstart"].detach() + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + sigma = (eta * paddle.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * + paddle.sqrt(1 - alpha_bar / alpha_bar_prev)) + # Equation 12. + # noise = th.randn_like(x) + noise = paddle.randn(x.shape, x.dtype) + mean_pred = (out["pred_xstart"] * paddle.sqrt(alpha_bar_prev) + + paddle.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = (paddle.cast( + (t != 0), 'float32').reshape([-1, *([1] * (len(x.shape) - 1))]) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return { + "sample": sample, + "pred_xstart": out_orig["pred_xstart"].detach() + } + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Generate samples from the model using DDIM. + + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + progress=progress, + eta=eta, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + """ + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = paddle.randn(shape) + + if skip_timesteps and init_image is None: + init_image = paddle.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = paddle.ones([shape[0]], dtype='int64') * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = paddle.to_tensor([i] * shape[0]) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = paddle.randint( + low=0, + high=model.num_classes, + shape=model_kwargs['y'].shape, + ) + sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim"):]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
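+
+    A minimal usage sketch (mirroring create_gaussian_diffusion() in
+    model_diffusion.py)::
+
+        betas = get_named_beta_schedule("linear", 1000)
+        diffusion = SpacedDiffusion(
+            use_timesteps=space_timesteps(1000, "ddim50"),
+            betas=betas,
+            model_mean_type=ModelMeanType.EPSILON,
+            model_var_type=ModelVarType.FIXED_SMALL,
+            rescale_timesteps=True,
+        )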
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, + **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, + **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, + self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + + def __init__(self, model, timestep_map, rescale_timesteps, + original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = paddle.to_tensor(self.timestep_map, + place=ts.place, + dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = paddle.cast(new_ts, + 'float32') * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = paddle.to_tensor(arr, place=timesteps.place)[timesteps] + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/paddlenlp/transformers/guided_diffusion_utils/losses.py b/paddlenlp/transformers/guided_diffusion_utils/losses.py new file mode 100755 index 000000000000..a43a53f2fca9 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/losses.py @@ -0,0 +1,26 @@ +""" +Helpers for various likelihood-based losses implemented by Paddle. These are ported from the original +Ho et al. 
diffusion models codebase: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py +""" +import paddle +import paddle.nn.functional as F + + +def spherical_dist_loss(x, y): + x = F.normalize(x, axis=-1) + y = F.normalize(y, axis=-1) + return (x - y).norm(axis=-1).divide( + paddle.to_tensor(2.0)).asin().pow(2).multiply(paddle.to_tensor(2.0)) + + +def tv_loss(input): + """L2 total variation loss, as in Mahendran et al.""" + input = F.pad(input, (0, 1, 0, 1), 'replicate') + x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] + y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] + return (x_diff**2 + y_diff**2).mean([1, 2, 3]) + + +def range_loss(input): + return (input - input.clip(-1, 1)).pow(2).mean([1, 2, 3]) diff --git a/paddlenlp/transformers/guided_diffusion_utils/make_cutouts.py b/paddlenlp/transformers/guided_diffusion_utils/make_cutouts.py new file mode 100755 index 000000000000..9d9831993923 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/make_cutouts.py @@ -0,0 +1,96 @@ +''' +This code is rewritten by Paddle based on Jina-ai/discoart. +https://github.com/jina-ai/discoart/blob/main/discoart/nn/make_cutouts.py +''' +import paddle +import paddle.nn as nn +from paddle.nn import functional as F +from .resize_right import resize + +from . import transforms as T + +skip_augs = False +padargs = {} + + +class MakeCutoutsDango(nn.Layer): + + def __init__(self, + cut_size, + Overview=4, + InnerCrop=0, + IC_Size_Pow=0.5, + IC_Grey_P=0.2): + super().__init__() + self.cut_size = cut_size + self.Overview = Overview + self.InnerCrop = InnerCrop + self.IC_Size_Pow = IC_Size_Pow + self.IC_Grey_P = IC_Grey_P + self.augs = nn.Sequential(*[ + T.RandomHorizontalFlip(prob=0.5), + T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01), + T.RandomAffine( + degrees=10, + translate=(0.05, 0.05), + interpolation=T.InterpolationMode.BILINEAR, + ), + T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01), + T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, + hue=0.1), + ]) + + def forward(self, input): + cutouts = [] + gray = T.Grayscale(3) + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + output_shape = [1, 3, self.cut_size, self.cut_size] + pad_input = F.pad( + input, + ( + (sideY - max_size) // 2, + (sideY - max_size) // 2, + (sideX - max_size) // 2, + (sideX - max_size) // 2, + ), + **padargs, + ) + cutout = resize(pad_input, out_shape=output_shape) + + if self.Overview > 0: + if self.Overview <= 4: + if self.Overview >= 1: + cutouts.append(cutout) + if self.Overview >= 2: + cutouts.append(gray(cutout)) + if self.Overview >= 3: + cutouts.append(cutout[:, :, :, ::-1]) + if self.Overview == 4: + cutouts.append(gray(cutout[:, :, :, ::-1])) + else: + cutout = resize(pad_input, out_shape=output_shape) + for _ in range(self.Overview): + cutouts.append(cutout) + + if self.InnerCrop > 0: + for i in range(self.InnerCrop): + size = int( + paddle.rand([1])**self.IC_Size_Pow * (max_size - min_size) + + min_size) + offsetx = paddle.randint(0, sideX - size + 1) + offsety = paddle.randint(0, sideY - size + 1) + cutout = input[:, :, offsety:offsety + size, + offsetx:offsetx + size] + if i <= int(self.IC_Grey_P * self.InnerCrop): + cutout = gray(cutout) + cutout = resize(cutout, out_shape=output_shape) + cutouts.append(cutout) + + cutouts = paddle.concat(cutouts) + if skip_augs is not True: + cutouts = 
self.augs(cutouts) + return cutouts diff --git a/paddlenlp/transformers/guided_diffusion_utils/model_diffusion.py b/paddlenlp/transformers/guided_diffusion_utils/model_diffusion.py new file mode 100644 index 000000000000..c9ae17bcb9f7 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/model_diffusion.py @@ -0,0 +1,96 @@ +''' +This code is based on +https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py +''' +from .gaussian_diffusion import get_named_beta_schedule, SpacedDiffusion, space_timesteps, ModelVarType, ModelMeanType +from .unet import UNetModel +from .sec_diff import SecondaryDiffusionImageNet2 + +NUM_CLASSES = 1000 + + +def create_unet_model( + image_size=512, + num_channels=256, + num_res_blocks=2, + channel_mult="", + learn_sigma=True, + class_cond=False, + attention_resolutions='32, 16, 8', + num_heads=4, + num_head_channels=64, + num_heads_upsample=-1, + use_scale_shift_norm=True, + dropout=0.0, + resblock_updown=True, + use_new_attention_order=False, +): + if channel_mult == "": + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple( + int(ch_mult) for ch_mult in channel_mult.split(",")) + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return UNetModel( + image_size=image_size, + in_channels=3, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_fp16=False, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + ) + + +def create_secondary_model(): + model = SecondaryDiffusionImageNet2() + return model + + +def create_gaussian_diffusion( + steps=250, + learn_sigma=True, + sigma_small=False, + noise_schedule="linear", + predict_xstart=False, + rescale_timesteps=True, +): + # propcess steps + timestep_respacing = f'ddim{steps}' + steps = (1000 // steps) * steps if steps < 1000 else steps + + betas = get_named_beta_schedule(noise_schedule, steps) + if not timestep_respacing: + timestep_respacing = [steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=(ModelMeanType.EPSILON + if not predict_xstart else ModelMeanType.START_X), + model_var_type=((ModelVarType.FIXED_LARGE + if not sigma_small else ModelVarType.FIXED_SMALL) + if not learn_sigma else ModelVarType.LEARNED_RANGE), + rescale_timesteps=rescale_timesteps, + ) diff --git a/paddlenlp/transformers/guided_diffusion_utils/perlin_noises.py b/paddlenlp/transformers/guided_diffusion_utils/perlin_noises.py new file mode 100755 index 000000000000..fe1688974de8 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/perlin_noises.py @@ -0,0 +1,85 @@ +''' +Perlin noise implementation by Paddle. 
+This code is rewritten based on: +https://github.com/jina-ai/discoart/blob/main/discoart/nn/perlin_noises.py +''' +import numpy as np +import paddle +import paddle.vision.transforms as PF +from PIL import Image, ImageOps + + +def interp(t): + return 3 * t**2 - 2 * t**3 + + +def perlin(width, height, scale=10): + gx, gy = paddle.randn([2, width + 1, height + 1, 1, 1]) + xs = paddle.linspace(0, 1, scale + 1)[:-1, None] + ys = paddle.linspace(0, 1, scale + 1)[None, :-1] + wx = 1 - interp(xs) + wy = 1 - interp(ys) + dots = 0 + dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys) + dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys) + dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)) + dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * + (1 - ys)) + return dots.transpose([0, 2, 1, 3]).reshape([width * scale, height * scale]) + + +def perlin_ms(octaves, width, height, grayscale): + out_array = [0.5] if grayscale else [0.5, 0.5, 0.5] + # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0] + for i in range(1 if grayscale else 3): + scale = 2**len(octaves) + oct_width = width + oct_height = height + for oct in octaves: + p = perlin(oct_width, oct_height, scale) + out_array[i] += p * oct + scale //= 2 + oct_width *= 2 + oct_height *= 2 + return paddle.concat(out_array) + + +def create_perlin_noise(octaves, width, height, grayscale, side_y, side_x): + out = perlin_ms(octaves, width, height, grayscale) + if grayscale: + out = PF.resize(size=(side_y, side_x), img=out.numpy()) + out = np.uint8(out) + out = Image.fromarray(out).convert('RGB') + else: + out = out.reshape([-1, 3, out.shape[0] // 3, out.shape[1]]) + out = out.squeeze().transpose([1, 2, 0]).numpy() + out = PF.resize(size=(side_y, side_x), img=out) + out = out.clip(0, 1) * 255 + out = np.uint8(out) + out = Image.fromarray(out) + + out = ImageOps.autocontrast(out) + return out + + +def regen_perlin(perlin_mode, side_y, side_x, batch_size): + if perlin_mode == 'color': + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, + False, side_y, side_x) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, + False, side_y, side_x) + elif perlin_mode == 'gray': + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, + True, side_y, side_x) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, + True, side_y, side_x) + else: + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, + False, side_y, side_x) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, + True, side_y, side_x) + + init = (PF.to_tensor(init).add(PF.to_tensor(init2)).divide( + paddle.to_tensor(2.0)).unsqueeze(0) * 2 - 1) + del init2 + return init.expand([batch_size, -1, -1, -1]) diff --git a/paddlenlp/transformers/guided_diffusion_utils/resize_right.py b/paddlenlp/transformers/guided_diffusion_utils/resize_right.py new file mode 100755 index 000000000000..a65ed0b9116d --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/resize_right.py @@ -0,0 +1,477 @@ +from fractions import Fraction +from math import ceil + +from math import pi + +import paddle +import paddle.nn as nn +import numpy +import numpy as np + +nnModuleWrapped = nn.Layer + + +def set_framework_dependencies(x): + if type(x) is numpy.ndarray: + to_dtype = lambda a: a + fw = numpy + else: + to_dtype = lambda a: paddle.cast(a, x.dtype) + fw = paddle + # eps = fw.finfo(fw.float32).eps + eps = paddle.to_tensor(np.finfo(np.float32).eps) + return fw, to_dtype, eps 
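+
+# Each interpolation kernel below declares the width of its support window via
+# the @support_sz decorator (cubic and lanczos2 use 4 input samples, lanczos3
+# uses 6, linear 2, box 1); resize() falls back to interp_method.support_sz
+# when the caller does not pass support_sz explicitly.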
+ + +def support_sz(sz): + + def wrapper(f): + f.support_sz = sz + return f + + return wrapper + + +@support_sz(4) +def cubic(x): + fw, to_dtype, eps = set_framework_dependencies(x) + absx = fw.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + + (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * + to_dtype((1. < absx) & (absx <= 2.))) + + +@support_sz(4) +def lanczos2(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / + ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) + + +@support_sz(6) +def lanczos3(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / + ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) + + +@support_sz(2) +def linear(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + + (1 - x) * to_dtype((0 <= x) & (x <= 1))) + + +@support_sz(1) +def box(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) + + +def resize(input, + scale_factors=None, + out_shape=None, + interp_method=cubic, + support_sz=None, + antialiasing=True, + by_convs=False, + scale_tolerance=None, + max_numerator=10, + pad_mode='constant'): + # get properties of the input tensor + in_shape, n_dims = input.shape, input.ndim + + # fw stands for framework that can be either numpy or paddle, + # determined by the input type + fw = numpy if type(input) is numpy.ndarray else paddle + eps = np.finfo(np.float32).eps if fw == numpy else paddle.to_tensor( + np.finfo(np.float32).eps) + device = input.place if fw is paddle else None + + # set missing scale factors or output shapem one according to another, + # scream if both missing. this is also where all the defults policies + # take place. also handling the by_convs attribute carefully. + scale_factors, out_shape, by_convs = set_scale_and_out_sz( + in_shape, out_shape, scale_factors, by_convs, scale_tolerance, + max_numerator, eps, fw) + + # sort indices of dimensions according to scale of each dimension. + # since we are going dim by dim this is efficient + sorted_filtered_dims_and_scales = [ + (dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim]) + for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind]) + if scale_factors[dim] != 1. + ] + # unless support size is specified by the user, it is an attribute + # of the interpolation method + if support_sz is None: + support_sz = interp_method.support_sz + + # output begins identical to input and changes with each iteration + output = input + + # iterate over dims + for (dim, scale_factor, dim_by_convs, in_sz, + out_sz) in sorted_filtered_dims_and_scales: + # STEP 1- PROJECTED GRID: The non-integer locations of the projection + # of output pixel locations to the input tensor + projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, + dim_by_convs, device) + + # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify + # the window size and the interpolation method (see inside function) + cur_interp_method, cur_support_sz = apply_antialiasing_if_needed( + interp_method, support_sz, scale_factor, antialiasing) + + # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels + # that influence it. 
Also calculate needed padding and update grid + # accoedingly + field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw, + eps, device) + + # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view, + # the input should be padded to handle the boundaries, coordinates + # should be updated. actual padding only occurs when weights are + # aplied (step 4). if using by_convs for this dim, then we need to + # calc right and left boundaries for each filter instead. + pad_sz, projected_grid, field_of_view = calc_pad_sz( + in_sz, out_sz, field_of_view, projected_grid, scale_factor, + dim_by_convs, fw, device) + # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in + # the field of view for each output pixel + weights = get_weights(cur_interp_method, projected_grid, field_of_view) + + # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying + # its set of weights with the pixel values in its field of view. + # We now multiply the fields of view with their matching weights. + # We do this by tensor multiplication and broadcasting. + # if by_convs is true for this dim, then we do this action by + # convolutions. this is equivalent but faster. + if not dim_by_convs: + output = apply_weights(output, field_of_view, weights, dim, n_dims, + pad_sz, pad_mode, fw) + else: + output = apply_convs(output, scale_factor, in_sz, out_sz, weights, + dim, pad_sz, pad_mode, fw) + return output + + +def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None): + # we start by having the ouput coordinates which are just integer locations + # in the special case when usin by_convs, we only need two cycles of grid + # points. the first and last. + grid_sz = out_sz if not by_convs else scale_factor.numerator + out_coordinates = fw_arange(grid_sz, fw, device) + + # This is projecting the ouput pixel locations in 1d to the input tensor, + # as non-integer locations. + # the following fomrula is derived in the paper + # "From Discrete to Continuous Convolutions" by Shocher et al. + return (out_coordinates / float(scale_factor) + (in_sz - 1) / 2 - + (out_sz - 1) / (2 * float(scale_factor))) + + +def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device): + # for each output pixel, map which input pixels influence it, in 1d. + # we start by calculating the leftmost neighbor, using half of the window + # size (eps is for when boundary is exact int) + left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw) + + # then we simply take all the pixel centers in the field by counting + # window size pixels from the left boundary + ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device) + return left_boundaries[:, None] + ordinal_numbers + + +def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor, + dim_by_convs, fw, device): + if not dim_by_convs: + # determine padding according to neighbor coords out of bound. 
+ # this is a generalized notion of padding, when pad<0 it means crop + pad_sz = [ + -field_of_view[0, 0].item(), + field_of_view[-1, -1].item() - in_sz + 1 + ] + + # since input image will be changed by padding, coordinates of both + # field_of_view and projected_grid need to be updated + field_of_view += pad_sz[0] + projected_grid += pad_sz[0] + + else: + # only used for by_convs, to calc the boundaries of each filter the + # number of distinct convolutions is the numerator of the scale factor + num_convs, stride = scale_factor.numerator, scale_factor.denominator + + # calculate left and right boundaries for each conv. left can also be + # negative right can be bigger than in_sz. such cases imply padding if + # needed. however if# both are in-bounds, it means we need to crop, + # practically apply the conv only on part of the image. + left_pads = -field_of_view[:, 0] + + # next calc is tricky, explanation by rows: + # 1) counting output pixels between the first position of each filter + # to the right boundary of the input + # 2) dividing it by number of filters to count how many 'jumps' + # each filter does + # 3) multiplying by the stride gives us the distance over the input + # coords done by all these jumps for each filter + # 4) to this distance we add the right boundary of the filter when + # placed in its leftmost position. so now we get the right boundary + # of that filter in input coord. + # 5) the padding size needed is obtained by subtracting the rightmost + # input coordinate. if the result is positive padding is needed. if + # negative then negative padding means shaving off pixel columns. + right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1) # (1) + // num_convs) # (2) + * stride # (3) + + field_of_view[:, -1] # (4) + - in_sz + 1) # (5) + + # in the by_convs case pad_sz is a list of left-right pairs. one per + # each filter + + pad_sz = list(zip(left_pads, right_pads)) + + return pad_sz, projected_grid, field_of_view + + +def get_weights(interp_method, projected_grid, field_of_view): + # the set of weights per each output pixels is the result of the chosen + # interpolation method applied to the distances between projected grid + # locations and the pixel-centers in the field of view (distances are + # directed, can be positive or negative) + weights = interp_method(projected_grid[:, None] - field_of_view) + + # we now carefully normalize the weights to sum to 1 per each output pixel + sum_weights = weights.sum(1, keepdim=True) + sum_weights[sum_weights == 0] = 1 + return weights / sum_weights + + +def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, + fw): + # for this operation we assume the resized dim is the first one. + # so we transpose and will transpose back after multiplying + tmp_input = fw_swapaxes(input, dim, 0, fw) + + # apply padding + tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode) + + # field_of_view is a tensor of order 2: for each output (1d location + # along cur dim)- a list of 1d neighbors locations. + # note that this whole operations is applied to each dim separately, + # this is why it is all in 1d. + # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1: + # for each output pixel (this time indicated in all dims), these are the + # values of the neighbors in the 1d field of view. note that we only + # consider neighbors along the current dim, but such set exists for every + # multi-dim location, hence the final tensor order is image_dims+1. 
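+    # e.g. if the transposed input has shape (in_sz, d1, d2) and field_of_view
+    # has shape (out_sz, support), then neighbors has shape
+    # (out_sz, support, d1, d2) and the weights are reshaped to
+    # (out_sz, support, 1, 1) before broadcasting.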
+ paddle.device.cuda.empty_cache() + neighbors = tmp_input[field_of_view] + + # weights is an order 2 tensor: for each output location along 1d- a list + # of weights matching the field of view. we augment it with ones, for + # broadcasting, so that when multiplies some tensor the weights affect + # only its first dim. + tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1))) + + # now we simply multiply the weights with the neighbors, and then sum + # along the field of view, to get a single value per out pixel + tmp_output = (neighbors * tmp_weights).sum(1) + # we transpose back the resized dim to its original position + return fw_swapaxes(tmp_output, 0, dim, fw) + + +def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz, + pad_mode, fw): + # for this operations we assume the resized dim is the last one. + # so we transpose and will transpose back after multiplying + input = fw_swapaxes(input, dim, -1, fw) + + # the stride for all convs is the denominator of the scale factor + stride, num_convs = scale_factor.denominator, scale_factor.numerator + + # prepare an empty tensor for the output + tmp_out_shape = list(input.shape) + tmp_out_shape[-1] = out_sz + tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device) + + # iterate over the conv operations. we have as many as the numerator + # of the scale-factor. for each we need boundaries and a filter. + for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)): + # apply padding (we pad last dim, padding can be negative) + pad_dim = input.ndim - 1 + tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim) + + # apply convolution over last dim. store in the output tensor with + # positional strides so that when the loop is comlete conv results are + # interwind + tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride) + + return fw_swapaxes(tmp_output, -1, dim, fw) + + +def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs, + scale_tolerance, max_numerator, eps, fw): + # eventually we must have both scale-factors and out-sizes for all in/out + # dims. 
however, we support many possible partial arguments + if scale_factors is None and out_shape is None: + raise ValueError("either scale_factors or out_shape should be " + "provided") + if out_shape is not None: + # if out_shape has less dims than in_shape, we defaultly resize the + # first dims for numpy and last dims for paddle + out_shape = (list(out_shape) + list(in_shape[len(out_shape):]) + if fw is numpy else list(in_shape[:-len(out_shape)]) + + list(out_shape)) + if scale_factors is None: + # if no scale given, we calculate it as the out to in ratio + # (not recomended) + scale_factors = [ + out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape) + ] + if scale_factors is not None: + # by default, if a single number is given as scale, we assume resizing + # two dims (most common are images with 2 spatial dims) + scale_factors = (scale_factors if isinstance( + scale_factors, (list, tuple)) else [scale_factors, scale_factors]) + # if less scale_factors than in_shape dims, we defaultly resize the + # first dims for numpy and last dims for paddle + scale_factors = (list(scale_factors) + [1] * + (len(in_shape) - len(scale_factors)) if fw is numpy + else [1] * (len(in_shape) - len(scale_factors)) + + list(scale_factors)) + if out_shape is None: + # when no out_shape given, it is calculated by multiplying the + # scale by the in_shape (not recomended) + out_shape = [ + ceil(scale_factor * in_sz) + for scale_factor, in_sz in zip(scale_factors, in_shape) + ] + # next part intentionally after out_shape determined for stability + # we fix by_convs to be a list of truth values in case it is not + if not isinstance(by_convs, (list, tuple)): + by_convs = [by_convs] * len(out_shape) + + # next loop fixes the scale for each dim to be either frac or float. + # this is determined by by_convs and by tolerance for scale accuracy. + for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)): + # first we fractionaize + if dim_by_convs: + frac = Fraction(1 / sf).limit_denominator(max_numerator) + frac = Fraction(numerator=frac.denominator, + denominator=frac.numerator) + + # if accuracy is within tolerance scale will be frac. if not, then + # it will be float and the by_convs attr will be set false for + # this dim + if scale_tolerance is None: + scale_tolerance = eps + if dim_by_convs and abs(frac - sf) < scale_tolerance: + scale_factors[ind] = frac + else: + scale_factors[ind] = float(sf) + by_convs[ind] = False + + return scale_factors, out_shape, by_convs + + +def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor, + antialiasing): + # antialiasing is "stretching" the field of view according to the scale + # factor (only for downscaling). this is low-pass filtering. this + # requires modifying both the interpolation (stretching the 1d + # function and multiplying by the scale-factor) and the window size. 
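+    # concretely, for a downscaling factor s < 1 the kernel becomes
+    # s * interp_method(s * x) and the window grows to support_sz / s.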
+ scale_factor = float(scale_factor) + if scale_factor >= 1.0 or not antialiasing: + return interp_method, support_sz + cur_interp_method = ( + lambda arg: scale_factor * interp_method(scale_factor * arg)) + cur_support_sz = support_sz / scale_factor + return cur_interp_method, cur_support_sz + + +def fw_ceil(x, fw): + if fw is numpy: + return fw.int_(fw.ceil(x)) + else: + return paddle.cast(x.ceil(), dtype='int64') + + +def fw_floor(x, fw): + if fw is numpy: + return fw.int_(fw.floor(x)) + else: + return paddle.cast(x.floor(), dtype='int64') + + +def fw_cat(x, fw): + if fw is numpy: + return fw.concatenate(x) + else: + return fw.concat(x) + + +def fw_swapaxes(x, ax_1, ax_2, fw): + if fw is numpy: + return fw.swapaxes(x, ax_1, ax_2) + else: + if ax_1 == -1: + ax_1 = len(x.shape) - 1 + if ax_2 == -1: + ax_2 = len(x.shape) - 1 + perm0 = list(range(len(x.shape))) + temp = ax_1 + perm0[temp] = ax_2 + perm0[ax_2] = temp + return fw.transpose(x, perm0) + + +def fw_pad(x, fw, pad_sz, pad_mode, dim=0): + if pad_sz == (0, 0): + return x + if fw is numpy: + pad_vec = [(0, 0)] * x.ndim + pad_vec[dim] = pad_sz + return fw.pad(x, pad_width=pad_vec, mode=pad_mode) + else: + if x.ndim < 3: + x = x[None, None, ...] + + pad_vec = [0] * ((x.ndim - 2) * 2) + pad_vec[0:2] = pad_sz + return fw_swapaxes( + fw.nn.functional.pad(fw_swapaxes(x, dim, -1, fw), + pad=pad_vec, + mode=pad_mode), dim, -1, fw) + + +def fw_conv(input, filter, stride): + # we want to apply 1d conv to any nd array. the way to do it is to reshape + # the input to a 4D tensor. first two dims are singeletons, 3rd dim stores + # all the spatial dims that we are not convolving along now. then we can + # apply conv2d with a 1xK filter. This convolves the same way all the other + # dims stored in the 3d dim. like depthwise conv over these. 
+ # TODO: numpy support + reshaped_input = input.reshape(1, 1, -1, input.shape[-1]) + reshaped_output = paddle.nn.functional.conv2d(reshaped_input, + filter.view(1, 1, 1, -1), + stride=(1, stride)) + return reshaped_output.reshape(*input.shape[:-1], -1) + + +def fw_arange(upper_bound, fw, device): + if fw is numpy: + return fw.arange(upper_bound) + else: + return fw.arange(upper_bound) + + +def fw_empty(shape, fw, device): + if fw is numpy: + return fw.empty(shape) + else: + return fw.empty(shape=shape) diff --git a/paddlenlp/transformers/guided_diffusion_utils/sec_diff.py b/paddlenlp/transformers/guided_diffusion_utils/sec_diff.py new file mode 100644 index 000000000000..84875cc1779d --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/sec_diff.py @@ -0,0 +1,139 @@ +''' +This code is rewritten by Paddle based on +https://github.com/jina-ai/discoart/blob/main/discoart/nn/sec_diff.py +''' +import math +from dataclasses import dataclass +from functools import partial + +import paddle +import paddle.nn as nn + + +@dataclass +class DiffusionOutput: + v: paddle.Tensor + pred: paddle.Tensor + eps: paddle.Tensor + + +class SkipBlock(nn.Layer): + + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return paddle.concat([self.main(input), self.skip(input)], axis=1) + + +def append_dims(x, n): + return x[(Ellipsis, *(None, ) * (n - x.ndim))] + + +def expand_to_planes(x, shape): + return paddle.tile(append_dims(x, len(shape)), [1, 1, *shape[2:]]) + + +def alpha_sigma_to_t(alpha, sigma): + return paddle.atan2(sigma, alpha) * 2 / math.pi + + +def t_to_alpha_sigma(t): + return paddle.cos(t * math.pi / 2), paddle.sin(t * math.pi / 2) + + +class SecondaryDiffusionImageNet2(nn.Layer): + + def __init__(self): + super().__init__() + c = 64 # The base channel count + cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8] + + self.timestep_embed = FourierFeatures(1, 16) + self.down = nn.AvgPool2D(2) + self.up = nn.Upsample(scale_factor=2, + mode='bilinear', + align_corners=False) + + self.net = nn.Sequential( + ConvBlock(3 + 16, cs[0]), + ConvBlock(cs[0], cs[0]), + SkipBlock([ + self.down, + ConvBlock(cs[0], cs[1]), + ConvBlock(cs[1], cs[1]), + SkipBlock([ + self.down, + ConvBlock(cs[1], cs[2]), + ConvBlock(cs[2], cs[2]), + SkipBlock([ + self.down, + ConvBlock(cs[2], cs[3]), + ConvBlock(cs[3], cs[3]), + SkipBlock([ + self.down, + ConvBlock(cs[3], cs[4]), + ConvBlock(cs[4], cs[4]), + SkipBlock([ + self.down, + ConvBlock(cs[4], cs[5]), + ConvBlock(cs[5], cs[5]), + ConvBlock(cs[5], cs[5]), + ConvBlock(cs[5], cs[4]), + self.up, + ]), + ConvBlock(cs[4] * 2, cs[4]), + ConvBlock(cs[4], cs[3]), + self.up, + ]), + ConvBlock(cs[3] * 2, cs[3]), + ConvBlock(cs[3], cs[2]), + self.up, + ]), + ConvBlock(cs[2] * 2, cs[2]), + ConvBlock(cs[2], cs[1]), + self.up, + ]), + ConvBlock(cs[1] * 2, cs[1]), + ConvBlock(cs[1], cs[0]), + self.up, + ]), + ConvBlock(cs[0] * 2, cs[0]), + nn.Conv2D(cs[0], 3, 3, padding=1), + ) + + def forward(self, input, t): + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), + input.shape) + v = self.net(paddle.concat([input, timestep_embed], axis=1)) + alphas, sigmas = map(partial(append_dims, n=v.ndim), + t_to_alpha_sigma(t)) + pred = input * alphas - v * sigmas + eps = input * sigmas + v * alphas + return DiffusionOutput(v, pred, eps) + + +class FourierFeatures(nn.Layer): + + def __init__(self, in_features, out_features, std=1.0): + super().__init__() + assert 
out_features % 2 == 0 + self.weight = paddle.create_parameter( + [out_features // 2, in_features], + dtype='float32', + default_initializer=nn.initializer.Normal(mean=0.0, std=std)) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return paddle.concat([f.cos(), f.sin()], axis=-1) + + +class ConvBlock(nn.Sequential): + + def __init__(self, c_in, c_out): + super().__init__( + nn.Conv2D(c_in, c_out, 3, padding=1), + nn.ReLU(), + ) diff --git a/paddlenlp/transformers/guided_diffusion_utils/transforms.py b/paddlenlp/transformers/guided_diffusion_utils/transforms.py new file mode 100755 index 000000000000..022be4688a92 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/transforms.py @@ -0,0 +1,806 @@ +''' +This code is rewritten by Paddle based on +https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py +''' +import math +import numbers +import warnings +from enum import Enum +from typing import List, Optional, Sequence, Tuple + +import paddle +import paddle.nn as nn +from paddle.nn.functional import grid_sample + + +class Normalize(nn.Layer): + + def __init__(self, mean, std): + super(Normalize, self).__init__() + self.mean = paddle.to_tensor(mean) + self.std = paddle.to_tensor(std) + + def forward(self, tensor): + dtype = tensor.dtype + mean = paddle.cast(self.mean, dtype=dtype).reshape([1, -1, 1, 1]) + std = paddle.cast(self.std, dtype=dtype).reshape([1, -1, 1, 1]) + result = tensor.subtract(mean).divide(std) + return result + + +class InterpolationMode(Enum): + """Interpolation modes + Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. + """ + + NEAREST = "nearest" + BILINEAR = "bilinear" + BICUBIC = "bicubic" + # For PIL compatibility + BOX = "box" + HAMMING = "hamming" + LANCZOS = "lanczos" + + +class Grayscale(nn.Layer): + + def __init__(self, num_output_channels): + super(Grayscale, self).__init__() + self.num_output_channels = num_output_channels + + def forward(self, x): + output = (0.2989 * x[:, 0:1, :, :] + 0.587 * x[:, 1:2, :, :] + + 0.114 * x[:, 2:3, :, :]) + if self.num_output_channels == 3: + return output.expand(x.shape) + + return output + + +class Lambda(nn.Layer): + + def __init__(self, func): + super(Lambda, self).__init__() + self.transform = func + + def forward(self, x): + return self.transform(x) + + +class RandomGrayscale(nn.Layer): + + def __init__(self, p): + super(RandomGrayscale, self).__init__() + self.prob = p + self.transform = Grayscale(3) + + def forward(self, x): + if paddle.rand([1]) < self.prob: + return self.transform(x) + else: + return x + + +class RandomHorizontalFlip(nn.Layer): + + def __init__(self, prob): + super(RandomHorizontalFlip, self).__init__() + self.prob = prob + + def forward(self, x): + if paddle.rand([1]) < self.prob: + return x[:, :, :, ::-1] + else: + return x + + +def _blend(img1, img2, ratio: float): + ratio = float(ratio) + bound = 1.0 + return (ratio * img1 + (1.0 - ratio) * img2).clip(0, bound) + + +def trunc_div(a, b): + ipt = paddle.divide(a, b) + sign_ipt = paddle.sign(ipt) + abs_ipt = paddle.abs(ipt) + abs_ipt = paddle.floor(abs_ipt) + out = paddle.multiply(sign_ipt, abs_ipt) + return out + + +def fmod(a, b): + return a - trunc_div(a, b) * b + + +def _rgb2hsv(img): + r, g, b = img.unbind(axis=-3) + + # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ + # src/libImaging/Convert.c#L330 + maxc = paddle.max(img, axis=-3) + minc = 
paddle.min(img, axis=-3) + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occuring so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + cr = maxc - minc + # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + ones = paddle.ones_like(maxc) + s = cr / paddle.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. + cr_divisor = paddle.where(eqc, ones, cr) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + hr = (maxc == r).cast('float32') * (bc - gc) + hg = ((maxc == g) & (maxc != r)).cast('float32') * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)).cast('float32') * (4.0 + gc - rc) + h = hr + hg + hb + h = fmod((h / 6.0 + 1.0), paddle.to_tensor(1.0)) + return paddle.stack((h, s, maxc), axis=-3) + + +def _hsv2rgb(img): + h, s, v = img.unbind(axis=-3) + i = paddle.floor(h * 6.0) + f = (h * 6.0) - i + i = i.cast(dtype='int32') + + p = paddle.clip((v * (1.0 - s)), 0.0, 1.0) + q = paddle.clip((v * (1.0 - s * f)), 0.0, 1.0) + t = paddle.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i.unsqueeze(axis=-3) == paddle.arange(6).reshape([-1, 1, 1]) + + a1 = paddle.stack((v, q, p, p, t, v), axis=-3) + a2 = paddle.stack((t, v, v, q, p, p), axis=-3) + a3 = paddle.stack((p, p, t, v, v, q), axis=-3) + a4 = paddle.stack((a1, a2, a3), axis=-4) + + return paddle.einsum("...ijk, ...xijk -> ...xjk", + mask.cast(dtype=img.dtype), a4) + + +def adjust_brightness(img, brightness_factor: float): + if brightness_factor < 0: + raise ValueError( + f"brightness_factor ({brightness_factor}) is not non-negative.") + + return _blend(img, paddle.zeros_like(img), brightness_factor) + + +def adjust_contrast(img, contrast_factor: float): + if contrast_factor < 0: + raise ValueError( + f"contrast_factor ({contrast_factor}) is not non-negative.") + + c = img.shape[1] + + if c == 3: + output = (0.2989 * img[:, 0:1, :, :] + 0.587 * img[:, 1:2, :, :] + + 0.114 * img[:, 2:3, :, :]) + mean = paddle.mean(output, axis=(-3, -2, -1), keepdim=True) + + else: + mean = paddle.mean(img, axis=(-3, -2, -1), keepdim=True) + + return _blend(img, mean, contrast_factor) + + +def adjust_hue(img, hue_factor: float): + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + img = _rgb2hsv(img) + h, s, v = img.unbind(axis=-3) + h = fmod(h + hue_factor, paddle.to_tensor(1.0)) + img = paddle.stack((h, s, v), axis=-3) + img_hue_adj = _hsv2rgb(img) + return img_hue_adj + + +def adjust_saturation(img, saturation_factor: float): + if saturation_factor < 0: + raise ValueError( + f"saturation_factor ({saturation_factor}) is not non-negative.") + + output = (0.2989 * img[:, 0:1, :, :] + 0.587 * img[:, 1:2, :, :] + + 0.114 * img[:, 2:3, :, :]) + + return _blend(img, output, saturation_factor) + + +class ColorJitter(nn.Layer): + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + 
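+        # brightness, contrast and saturation take either a single
+        # non-negative amount v (interpreted as the range (1 - v, 1 + v),
+        # clipped at 0) or an explicit (min, max) range; hue takes an amount h
+        # giving (-h, h) or a (min, max) range within [-0.5, 0.5].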
super(ColorJitter, self).__init__() + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, + "hue", + center=0, + bound=(-0.5, 0.5), + clip_first_on_zero=False) + + def _check_input(self, + value, + name, + center=1, + bound=(0, float("inf")), + clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + f"If {name} is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}") + else: + raise TypeError( + f"{name} should be a single number or a list/tuple with length 2." + ) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params( + brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]], + ) -> Tuple[paddle.Tensor, Optional[float], Optional[float], Optional[float], + Optional[float]]: + """Get the parameters for the randomized transform to be applied on image. + + Args: + brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen + uniformly. Pass None to turn off the transformation. + contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen + uniformly. Pass None to turn off the transformation. + saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen + uniformly. Pass None to turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. + Pass None to turn off the transformation. + + Returns: + tuple: The parameters used to apply the randomized transform + along with their random order. + """ + fn_idx = paddle.randperm(4) + + b = None if brightness is None else paddle.empty([1]).uniform_( + brightness[0], brightness[1]) + c = None if contrast is None else paddle.empty([1]).uniform_( + contrast[0], contrast[1]) + s = None if saturation is None else paddle.empty([1]).uniform_( + saturation[0], saturation[1]) + h = None if hue is None else paddle.empty([1]).uniform_(hue[0], hue[1]) + + return fn_idx, b, c, s, h + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Input image. + + Returns: + PIL Image or Tensor: Color jittered image. 
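+
+        Note:
+            The four adjustments are applied in the random order returned by
+            get_params() (a permutation over brightness, contrast, saturation
+            and hue).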
+ """ + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = adjust_hue(img, hue_factor) + + return img + + def __repr__(self) -> str: + s = (f"{self.__class__.__name__}(" + f"brightness={self.brightness}" + f", contrast={self.contrast}" + f", saturation={self.saturation}" + f", hue={self.hue})") + return s + + +def _apply_grid_transform(img, grid, mode: str, fill: Optional[List[float]]): + + if img.shape[0] > 1: + # Apply same grid to a batch of images + grid = grid.expand( + [img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]]) + + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice + if fill is not None: + dummy = paddle.ones((img.shape[0], 1, img.shape[2], img.shape[3]), + dtype=img.dtype) + img = paddle.concat((img, dummy), axis=1) + + img = grid_sample(img, + grid, + mode=mode, + padding_mode="zeros", + align_corners=False) + + # Fill with required color + if fill is not None: + mask = img[:, -1:, :, :] # N * 1 * H * W + img = img[:, :-1, :, :] # N * C * H * W + mask = mask.expand_as(img) + len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 + fill_img = paddle.to_tensor(fill, dtype=img.dtype).reshape( + [1, len_fill, 1, 1]).expand_as(img) + if mode == "nearest": + mask = mask < 0.5 + img[mask] = fill_img[mask] + else: # 'bilinear' + img = img * mask + (1.0 - mask) * fill_img + return img + + +def _gen_affine_grid( + theta, + w: int, + h: int, + ow: int, + oh: int, +): + # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ + # AffineGridGenerator.cpp#L18 + # Difference with AffineGridGenerator is that: + # 1) we normalize grid values after applying theta + # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate + + d = 0.5 + base_grid = paddle.empty([1, oh, ow, 3], dtype=theta.dtype) + x_grid = paddle.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, num=ow) + base_grid[..., 0] = (x_grid) + y_grid = paddle.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, + num=oh).unsqueeze_(-1) + base_grid[..., 1] = (y_grid) + base_grid[..., 2] = 1.0 + rescaled_theta = theta.transpose([0, 2, 1]) / paddle.to_tensor( + [0.5 * w, 0.5 * h], dtype=theta.dtype) + output_grid = base_grid.reshape([1, oh * ow, 3]).bmm(rescaled_theta) + return output_grid.reshape([1, oh, ow, 2]) + + +def affine_impl(img, + matrix: List[float], + interpolation: str = "nearest", + fill: Optional[List[float]] = None): + theta = paddle.to_tensor(matrix, dtype=img.dtype).reshape([1, 2, 3]) + shape = img.shape + # grid will be generated on the same device as theta and img + grid = _gen_affine_grid(theta, + w=shape[-1], + h=shape[-2], + ow=shape[-1], + oh=shape[-2]) + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _get_inverse_affine_matrix(center: List[float], + angle: float, + translate: List[float], + scale: float, + shear: List[float], + inverted: bool = True) -> List[float]: + # Helper method to compute inverse matrix for affine transformation + + # Pillow requires inverse affine 
transformation matrix: + # Affine matrix is : M = T * C * RotateScaleShear * C^-1 + # + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RotateScaleShear is rotation with scale and shear matrix + # + # RotateScaleShear(a, s, (sx, sy)) = + # = R(a) * S(s) * SHy(sy) * SHx(sx) + # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] + # [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] + # [ 0 , 0 , 1 ] + # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: + # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] + # [0, 1 ] [-tan(s), 1] + # + # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1 + + rot = math.radians(angle) + sx = math.radians(shear[0]) + sy = math.radians(shear[1]) + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = math.cos(rot - sy) / math.cos(sy) + b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot) + c = math.sin(rot - sy) / math.cos(sy) + d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot) + + if inverted: + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + matrix = [d, -b, 0.0, -c, a, 0.0] + matrix = [x / scale for x in matrix] + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx + matrix[5] += cy + else: + matrix = [a, b, 0.0, c, d, 0.0] + matrix = [x * scale for x in matrix] + # Apply inverse of center translation: RSS * C^-1 + matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy) + matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy) + # Apply translation and center : T * C * RSS * C^-1 + matrix[2] += cx + tx + matrix[5] += cy + ty + + return matrix + + +def affine( + img, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, + fillcolor: Optional[List[float]] = None, + center: Optional[List[int]] = None, +): + """Apply affine transformation on the image keeping image center invariant. + If the image is paddle Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): image to transform. + angle (number): rotation angle in degrees between -180 and 180, clockwise direction. + translate (sequence of integers): horizontal and vertical translations (post-rotation translation) + scale (float): overall scale + shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction. + If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while + the second value corresponds to a shear parallel to the y axis. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, + but deprecated since 0.13 and will be removed in 0.15. 
Please use InterpolationMode enum. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + + .. note:: + In torchscript mode single int/float value is not supported, please use a sequence + of length 1: ``[value, ]``. + fillcolor (sequence or number, optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead. + resample (int, optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation`` + instead. + center (sequence, optional): Optional center of rotation. Origin is the upper left corner. + Default is the center of the image. + + Returns: + PIL Image or Tensor: Transformed image. + """ + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " + "Please use InterpolationMode enum.") + interpolation = _interpolation_modes_from_int(interpolation) + + if fillcolor is not None: + warnings.warn( + "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. " + "Please use 'fill' instead.") + fill = fillcolor + + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError( + "Shear should be either a single value or a sequence of two values") + + if not isinstance(interpolation, InterpolationMode): + raise TypeError("Argument interpolation should be a InterpolationMode") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError( + f"Shear should be a sequence containing two values. Got {shear}") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + center_f = [0.0, 0.0] + if center is not None: + _, height, width = img.shape[0], img.shape[1], img.shape[2] + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
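+        # e.g. for a 224 x 224 image, center=(112, 112) gives
+        # center_f = [0.0, 0.0] (the image center).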
+ center_f = [ + 1.0 * (c - s * 0.5) for c, s in zip(center, [width, height]) + ] + + translate_f = [1.0 * t for t in translate] + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, + shear) + return affine_impl(img, + matrix=matrix, + interpolation=interpolation.value, + fill=fill) + + +def _interpolation_modes_from_int(i: int) -> InterpolationMode: + inverse_modes_mapping = { + 0: InterpolationMode.NEAREST, + 2: InterpolationMode.BILINEAR, + 3: InterpolationMode.BICUBIC, + 4: InterpolationMode.BOX, + 5: InterpolationMode.HAMMING, + 1: InterpolationMode.LANCZOS, + } + return inverse_modes_mapping[i] + + +def _check_sequence_input(x, name, req_sizes): + msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join( + [str(s) for s in req_sizes]) + if not isinstance(x, Sequence): + raise TypeError(f"{name} should be a sequence of length {msg}.") + if len(x) not in req_sizes: + raise ValueError(f"{name} should be sequence of length {msg}.") + + +def _setup_angle(x, name, req_sizes=(2, )): + if isinstance(x, numbers.Number): + if x < 0: + raise ValueError( + f"If {name} is a single number, it must be positive.") + x = [-x, x] + else: + _check_sequence_input(x, name, req_sizes) + + return [float(d) for d in x] + + +class RandomAffine(nn.Layer): + """Random affine transformation of the image keeping center invariant. + If the image is paddle Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + degrees (sequence or number): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Set to 0 to deactivate rotations. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or number, optional): Range of degrees to select from. + If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) + will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the + range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, + a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. + Will not apply shear by default. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, + but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. + fill (sequence or number): Pixel fill value for the area outside the transformed + image. Default is ``0``. If given a number, the value is used for all bands respectively. + fillcolor (sequence or number, optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead. 
+ resample (int, optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation`` + instead. + center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__( + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + fillcolor=None, + resample=None, + center=None, + ): + super(RandomAffine, self).__init__() + if resample is not None: + warnings.warn( + "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. " + "Please use 'interpolation' instead.") + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " + "Please use InterpolationMode enum.") + interpolation = _interpolation_modes_from_int(interpolation) + + if fillcolor is not None: + warnings.warn( + "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. " + "Please use 'fill' instead.") + fill = fillcolor + + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2, )) + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError( + "translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2, )) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.resample = self.interpolation = interpolation + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fillcolor = self.fill = fill + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2, )) + + self.center = center + + @staticmethod + def get_params( + degrees: List[float], + translate: Optional[List[float]], + scale_ranges: Optional[List[float]], + shears: Optional[List[float]], + img_size: List[int], + ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: + """Get parameters for affine transformation + + Returns: + params to be passed to the affine transformation + """ + angle = float( + paddle.empty([1]).uniform_(float(degrees[0]), float(degrees[1]))) + if translate is not None: + max_dx = float(translate[0] * img_size[0]) + max_dy = float(translate[1] * img_size[1]) + tx = int(float(paddle.empty([1]).uniform_(-max_dx, max_dx))) + ty = int(float(paddle.empty([1]).uniform_(-max_dy, max_dy))) + translations = (tx, ty) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = float( + paddle.empty([1]).uniform_(scale_ranges[0], scale_ranges[1])) + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if shears is not None: + shear_x = float(paddle.empty([1]).uniform_(shears[0], shears[1])) + if len(shears) == 4: + shear_y = float( + paddle.empty([1]).uniform_(shears[2], shears[3])) + + shear = (shear_x, shear_y) + + return angle, translations, scale, shear + + def forward(self, img): + fill = 
self.fill + channels, height, width = img.shape[1], img.shape[2], img.shape[3] + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + else: + fill = [float(f) for f in fill] + + img_size = [width, height] # flip for keeping BC on get_params call + + ret = self.get_params(self.degrees, self.translate, self.scale, + self.shear, img_size) + + return affine(img, + *ret, + interpolation=self.interpolation, + fill=fill, + center=self.center) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(degrees={self.degrees}" + s += f", translate={self.translate}" if self.translate is not None else "" + s += f", scale={self.scale}" if self.scale is not None else "" + s += f", shear={self.shear}" if self.shear is not None else "" + s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else "" + s += f", fill={self.fill}" if self.fill != 0 else "" + s += f", center={self.center}" if self.center is not None else "" + s += ")" + + return s diff --git a/paddlenlp/transformers/guided_diffusion_utils/unet.py b/paddlenlp/transformers/guided_diffusion_utils/unet.py new file mode 100755 index 000000000000..67f8697830a5 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/unet.py @@ -0,0 +1,723 @@ +''' +This code is rewritten by Paddle based on +https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py +''' +import math +from abc import abstractmethod + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class GroupNorm32(nn.GroupNorm): + + def forward(self, x): + return super().forward(x) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1D(*args, **kwargs) + elif dims == 2: + return nn.Conv2D(*args, **kwargs) + elif dims == 3: + return nn.Conv3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1D(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2D(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(axis=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. 
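+
+    For example, ``normalization(128)`` returns a ``GroupNorm32`` with 32 groups;
+    ``channels`` should therefore be a multiple of 32.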
+ """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = paddle.exp(-math.log(max_period) * + paddle.arange(start=0, end=half, dtype=paddle.float32) / + half) + args = paddle.cast(timesteps[:, None], 'float32') * freqs[None] + embedding = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1) + if dim % 2: + embedding = paddle.concat( + [embedding, paddle.zeros_like(embedding[:, :1])], axis=-1) + return embedding + + +class AttentionPool2d(nn.Layer): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = self.create_parameter( + paddle.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c = x.shape[:2] + x = paddle.reshape(x, [b, c, -1]) + x = paddle.concat([x.mean(dim=-1, keepdim=True), x], + axis=-1) # NC(HW+1) + x = x + paddle.cast(self.positional_embedding[None, :, :], + x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Layer): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Layer): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, + self.channels, + self.out_channels, + 3, + padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Layer): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=1) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.Silu(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.Silu(), + linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.Silu(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, + self.out_channels, + self.out_channels, + 3, + padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, + channels, + self.out_channels, + 3, + padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
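+
+        Illustrative example (shapes are arbitrary; ``channels`` is kept divisible
+        by 32 for the GroupNorm layers inside the block)::
+
+            block = ResBlock(channels=64, emb_channels=256, dropout=0.0)
+            x = paddle.randn([2, 64, 32, 32])
+            emb = paddle.randn([2, 256])
+            out = block(x, emb)  # same shape as x: [2, 64, 32, 32]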
+ """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb) + emb_out = paddle.cast(emb_out, h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = paddle.chunk(emb_out, 2, axis=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Layer): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + b, c, *spatial = x.shape + x = paddle.reshape(x, [b, c, -1]) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return paddle.reshape(x + h, [b, c, *spatial]) + + +class QKVAttentionLegacy(nn.Layer): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = paddle.reshape( + qkv, [bs * self.n_heads, ch * 3, length]).split(3, axis=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum( + "bct,bcs->bts", q * scale, + k * scale) # More stable with f16 than dividing afterwards + weight = paddle.cast(F.softmax(paddle.cast(weight, 'float32'), axis=-1), + weight.dtype) + a = paddle.einsum("bts,bcs->bct", weight, v) + + return paddle.reshape(a, [bs, -1, length]) + + +class QKVAttention(nn.Layer): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, axis=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = paddle.cast(F.softmax(paddle.cast(weight, 'float32'), axis=-1), + weight.dtype) + a = paddle.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) + return paddle.reshape(a, [bs, -1, length]) + + +class UNetModel(nn.Layer): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.dtype = paddle.float16 if use_fp16 else paddle.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.Silu(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.LayerList([ + TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, + padding=1)) + ]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch))) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.LayerList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + 
use_new_attention_order=use_new_attention_order, + )) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) if resblock_updown else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.Silu(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, + self.model_channels)) + if self.num_classes is not None: + assert y.shape == (x.shape[0], ) + emb = emb + self.label_emb(y) + + h = paddle.cast(x, self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = paddle.concat([h, hs.pop()], axis=1) + h = module(h, emb) + return self.out(h) + + +class SuperResModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + + def __init__(self, image_size, in_channels, *args, **kwargs): + super().__init__(image_size, in_channels * 2, *args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate(low_res, (new_height, new_width), + mode="bilinear") + x = paddle.concat([x, upsampled], axis=1) + return super().forward(x, timesteps, **kwargs) diff --git a/paddlenlp/transformers/guided_diffusion_utils/utils.py b/paddlenlp/transformers/guided_diffusion_utils/utils.py new file mode 100644 index 000000000000..e4c0478dbfc0 --- /dev/null +++ b/paddlenlp/transformers/guided_diffusion_utils/utils.py @@ -0,0 +1,442 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +This code is rewritten by Paddle based on Jina-ai/discoart. 
+https://github.com/jina-ai/discoart/blob/main/discoart/runner.py +''' +import paddle +import gc +import random +import numpy as np +import paddle +import paddle.vision.transforms as T + +from PIL import Image +from pathlib import Path +from paddle.utils import try_import +from .losses import range_loss, spherical_dist_loss, tv_loss +from .make_cutouts import MakeCutoutsDango +from .sec_diff import alpha_sigma_to_t +from .transforms import Normalize +from .perlin_noises import create_perlin_noise, regen_perlin +import random +from ..image_utils import load_image + +__all__ = ["DiscoDiffusionMixin"] + + +def set_seed(seed): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) + + +class DiscoDiffusionMixin: + + def disco_diffusion_generate(self, + target_text_embeds, + init_image=None, + output_dir='outputs/', + width_height=[1280, 768], + skip_steps=0, + cut_ic_pow=1, + init_scale=1000, + clip_guidance_scale=5000, + tv_scale=0, + range_scale=0, + sat_scale=0, + cutn_batches=4, + perlin_init=False, + perlin_mode='mixed', + seed=None, + eta=0.8, + clamp_grad=True, + clamp_max=0.05, + cut_overview='[12]*400+[4]*600', + cut_innercut='[4]*400+[12]*600', + cut_icgray_p='[0.2]*400+[0]*600', + save_rate=10, + n_batches=1, + batch_name="", + use_secondary_model=True, + randomize_class=True, + clip_denoised=False, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, + 0.27577711]): + r''' + The DiffusionMixin diffusion_generate method. + + Args: + init_image (Path, optional): + Recall that in the image sequence above, the first image shown is just noise. If an init_image + is provided, diffusion will replace the noise with the init_image as its starting state. To use + an init_image, upload the image to the Colab instance or your Google Drive, and enter the full + image path here. If using an init_image, you may need to increase skip_steps to ~ 50% of total + steps to retain the character of the init. See skip_steps above for further discussion. + Default to `None`. + output_dir (Path, optional): + Output directory. + Default to `disco_diffusion_clip_vitb32_out`. + width_height (List[int, int], optional): + Desired final image size, in pixels. You can have a square, wide, or tall image, but each edge + length should be set to a multiple of 64px, and a minimum of 512px on the default CLIP model setting. + If you forget to use multiples of 64px in your dimensions, DD will adjust the dimensions of your + image to make it so. + Default to `[1280, 768]`. + skip_steps (int, optional): + Consider the chart shown here. Noise scheduling (denoise strength) starts very high and progressively + gets lower and lower as diffusion steps progress. The noise levels in the first few steps are very high, + so images change dramatically in early steps.As DD moves along the curve, noise levels (and thus the + amount an image changes per step) declines, and image coherence from one step to the next increases. + The first few steps of denoising are often so dramatic that some steps (maybe 10-15% of total) can be + skipped without affecting the final image. You can experiment with this as a way to cut render times. 
+ If you skip too many steps, however, the remaining noise may not be high enough to generate new content, + and thus may not have time left to finish an image satisfactorily.Also, depending on your other settings, + you may need to skip steps to prevent CLIP from overshooting your goal, resulting in blown out colors + (hyper saturated, solid white, or solid black regions) or otherwise poor image quality. Consider that + the denoising process is at its strongest in the early steps, so skipping steps can sometimes mitigate + other problems.Lastly, if using an init_image, you will need to skip ~50% of the diffusion steps to retain + the shapes in the original init image. However, if you're using an init_image, you can also adjust + skip_steps up or down for creative reasons. With low skip_steps you can get a result "inspired by" + the init_image which will retain the colors and rough layout and shapes but look quite different. + With high skip_steps you can preserve most of the init_image contents and just do fine tuning of the texture. + Default to `0`. + steps: + When creating an image, the denoising curve is subdivided into steps for processing. Each step (or iteration) + involves the AI looking at subsets of the image called 'cuts' and calculating the 'direction' the image + should be guided to be more like the prompt. Then it adjusts the image with the help of the diffusion denoiser, + and moves to the next step.Increasing steps will provide more opportunities for the AI to adjust the image, + and each adjustment will be smaller, and thus will yield a more precise, detailed image. Increasing steps + comes at the expense of longer render times. Also, while increasing steps should generally increase image + quality, there is a diminishing return on additional steps beyond 250 - 500 steps. However, some intricate + images can take 1000, 2000, or more steps. It is really up to the user. Just know that the render time is + directly related to the number of steps, and many other parameters have a major impact on image quality, without + costing additional time. + cut_ic_pow (int, optional): + This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and + therefore the cuts themselves will be smaller and provide finer details. If you have too many or too-small + inner cuts, you may lose overall image coherency and/or it may cause an undesirable 'mosaic' effect. + Low cut_ic_pow values will allow the inner cuts to be larger, helping image coherency while still helping + with some details. + Default to `1`. + init_scale (int, optional): + This controls how strongly CLIP will try to match the init_image provided. This is balanced against the + clip_guidance_scale (CGS) above. Too much init scale, and the image won't change much during diffusion. + Too much CGS and the init image will be lost. + Default to `1000`. + clip_guidance_scale (int, optional): + CGS is one of the most important parameters you will use. It tells DD how strongly you want CLIP to move + toward your prompt each timestep. Higher is generally better, but if CGS is too strong it will overshoot + the goal and distort the image. So a happy medium is needed, and it takes experience to learn how to adjust + CGS. Note that this parameter generally scales with image dimensions. In other words, if you increase your + total dimensions by 50% (e.g. a change from 512 x 512 to 512 x 768), then to maintain the same effect on the + image, you'd want to increase clip_guidance_scale from 5000 to 7500. 
Of the basic settings, clip_guidance_scale, + steps and skip_steps are the most important contributors to image quality, so learn them well. + Default to `5000`. + tv_scale (int, optional): + Total variance denoising. Optional, set to zero to turn off. Controls smoothness of final output. If used, + tv_scale will try to smooth out your final image to reduce overall noise. If your image is too 'crunchy', + increase tv_scale. TV denoising is good at preserving edges while smoothing away noise in flat regions. + See https://en.wikipedia.org/wiki/Total_variation_denoising + Default to `0`. + range_scale (int, optional): + Optional, set to zero to turn off. Used for adjustment of color contrast. Lower range_scale will increase + contrast. Very low numbers create a reduced color palette, resulting in more vibrant or poster-like images. + Higher range_scale will reduce contrast, for more muted images. + Default to `0`. + sat_scale (int, optional): + Saturation scale. Optional, set to zero to turn off. If used, sat_scale will help mitigate oversaturation. + If your image is too saturated, increase sat_scale to reduce the saturation. + Default to `0`. + cutn_batches (int, optional): + Each iteration, the AI cuts the image into smaller pieces known as cuts, and compares each cut to the prompt + to decide how to guide the next diffusion step. More cuts can generally lead to better images, since DD has + more chances to fine-tune the image precision in each timestep. Additional cuts are memory intensive, however, + and if DD tries to evaluate too many cuts at once, it can run out of memory. You can use cutn_batches to increase + cuts per timestep without increasing memory usage. At the default settings, DD is scheduled to do 16 cuts per + timestep. If cutn_batches is set to 1, there will indeed only be 16 cuts total per timestep. However, if + cutn_batches is increased to 4, DD will do 64 cuts total in each timestep, divided into 4 sequential batches + of 16 cuts each. Because the cuts are being evaluated only 16 at a time, DD uses the memory required for only 16 cuts, + but gives you the quality benefit of 64 cuts. The tradeoff, of course, is that this will take ~4 times as long to + render each image.So, (scheduled cuts) x (cutn_batches) = (total cuts per timestep). Increasing cutn_batches will + increase render times, however, as the work is being done sequentially. DD's default cut schedule is a good place + to start, but the cut schedule can be adjusted in the Cutn Scheduling section, explained below. + Default to `4`. + perlin_init (bool, optional): + Normally, DD will use an image filled with random noise as a starting point for the diffusion curve. + If perlin_init is selected, DD will instead use a Perlin noise model as an initial state. Perlin has very + interesting characteristics, distinct from random noise, so it's worth experimenting with this for your projects. + Beyond perlin, you can, of course, generate your own noise images (such as with GIMP, etc) and use them as an + init_image (without skipping steps). Choosing perlin_init does not affect the actual diffusion process, just the + starting point for the diffusion. Please note that selecting a perlin_init will replace and override any init_image + you may have specified. 
Further, because the 2D, 3D and video animation systems all rely on the init_image system, + if you enable Perlin while using animation modes, the perlin_init will jump in front of any previous image or video + input, and DD will NOT give you the expected sequence of coherent images. All of that said, using Perlin and + animation modes together do make a very colorful rainbow effect, which can be used creatively. + Default to `False`. + perlin_mode (str, optional): + sets type of Perlin noise: colored, gray, or a mix of both, giving you additional options for noise types. Experiment + to see what these do in your projects. + Default to `mixed`. + seed (int, optional): + Deep in the diffusion code, there is a random number seed which is used as the basis for determining the initial + state of the diffusion. By default, this is random, but you can also specify your own seed. This is useful if you like a + particular result and would like to run more iterations that will be similar. After each run, the actual seed value used will be + reported in the parameters report, and can be reused if desired by entering seed # here. If a specific numerical seed is used + repeatedly, the resulting images will be quite similar but not identical. + Default to `None`. + eta (float, optional): + Eta (greek letter η) is a diffusion model variable that mixes in a random amount of scaled noise into each timestep. + 0 is no noise, 1.0 is more noise. As with most DD parameters, you can go below zero for eta, but it may give you + unpredictable results. The steps parameter has a close relationship with the eta parameter. If you set eta to 0, + then you can get decent output with only 50-75 steps. Setting eta to 1.0 favors higher step counts, ideally around + 250 and up. eta has a subtle, unpredictable effect on image, so you'll need to experiment to see how this affects your projects. + Default to `0.8`. + clamp_grad (bool, optional): + As I understand it, clamp_grad is an internal limiter that stops DD from producing extreme results. Try your images with and without + clamp_grad. If the image changes drastically with clamp_grad turned off, it probably means your clip_guidance_scale is too high and + should be reduced. + Default to `True`. + clamp_max (float, optional): + Sets the value of the clamp_grad limitation. Default is 0.05, providing for smoother, more muted coloration in images, but setting + higher values (0.15-0.3) can provide interesting contrast and vibrancy. + Default to `0.05`. + cut_overview (str, optional): + The schedule of overview cuts. + Default to `'[12]*400+[4]*600'`. + cut_innercut (str, optional): + The schedule of inner cuts. + Default to `'[4]*400+[12]*600'`. + cut_icgray_p (str, optional): + This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and therefore the cuts + themselves will be smaller and provide finer details. If you have too many or too-small inner cuts, you may lose overall + image coherency and/or it may cause an undesirable 'mosaic' effect. Low cut_ic_pow values will allow the inner cuts to be + larger, helping image coherency while still helping with some details. + Default to `'[0.2]*400+[0]*600'`. + save_rate (int, optional): + During a diffusion run, you can monitor the progress of each image being created with this variable. If display_rate is set + to 50, DD will show you the in-progress image every 50 timesteps. 
Setting this to a lower value, like 5 or 10, is a good way + to get an early peek at where your image is heading. If you don't like the progression, just interrupt execution, change some + settings, and re-run. If you are planning a long, unmonitored batch, it's better to set display_rate equal to steps, because + displaying interim images does slow Colab down slightly. + Default to `10`. + n_batches (int, optional): + This variable sets the number of still images you want DD to create. If you are using an animation mode (see below for details) + DD will ignore n_batches and create a single set of animated frames based on the animation settings. + Default to `1`. + batch_name (str, optional): + The name of the batch, the batch id will be named as "progress-[batch_name]-seed-[range(n_batches)]-[save_rate]". To avoid your + artworks be overridden by other users, please use a unique name. + Default to `''`. + use_secondary_model (bool, optional): + Whether or not use secondary model. + Default to `True`. + randomize_class (bool, optional): + Random class. + Default to `True`. + clip_denoised (bool, optional): + Clip denoised. + Default to `False`. + ''' + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + batch_size = 1 + normalize = Normalize( + mean=image_mean, + std=image_std, + ) + side_x = (width_height[0] // 64) * 64 + side_y = (width_height[1] // 64) * 64 + cut_overview = eval(cut_overview) + cut_innercut = eval(cut_innercut) + cut_icgray_p = eval(cut_icgray_p) + + seed = seed or random.randint(0, 2**32) + set_seed(seed) + + init = None + if init_image: + d = load_image(init_image) + init = T.to_tensor(d).unsqueeze(0) * 2 - 1 + + if perlin_init: + if perlin_mode == 'color': + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], + 1, 1, False, side_y, side_x) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], + 4, 4, False, side_y, side_x) + elif perlin_mode == 'gray': + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], + 1, 1, True, side_y, side_x) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], + 4, 4, True, side_y, side_x) + else: + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], + 1, 1, False, side_y, side_x) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], + 4, 4, True, side_y, side_x) + init = (T.to_tensor(init).add(T.to_tensor(init2)).divide( + paddle.to_tensor(2.0)).unsqueeze(0) * 2 - 1) + del init2 + + if init is not None and init_scale: + lpips = try_import("paddle_lpips") + lpips_model = lpips.LPIPS(net='vgg') + lpips_model.eval() + for parameter in lpips_model.parameters(): + parameter.stop_gradient = True + + cur_t = None + + def cond_fn(x, t, y=None): + x_is_NaN = False + n = x.shape[0] + x = paddle.to_tensor(x.detach(), dtype='float32') + x.stop_gradient = False + if use_secondary_model: + alpha = paddle.to_tensor( + self.diffusion.sqrt_alphas_cumprod[cur_t], dtype='float32') + sigma = paddle.to_tensor( + self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t], + dtype='float32') + cosine_t = alpha_sigma_to_t(alpha, sigma) + cosine_t = paddle.tile( + paddle.to_tensor(cosine_t.detach().cpu().numpy()), [n]) + cosine_t.stop_gradient = False + out = self.secondary_model(x, cosine_t).pred + fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t] + x_in_d = out * fac + x * (1 - fac) + x_in = x_in_d.detach() + x_in.stop_gradient = False + x_in_grad = paddle.zeros_like(x_in, dtype='float32') + else: + t = paddle.ones([n], dtype='int64') * cur_t + out = 
self.diffusion.p_mean_variance(self.unet_model, + x, + t, + clip_denoised=False, + model_kwargs={'y': y}) + fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t] + x_in_d = out['pred_xstart'].astype("float32") * fac + x * (1 - + fac) + x_in = x_in_d.detach() + x_in.stop_gradient = False + x_in_grad = paddle.zeros_like(x_in, dtype='float32') + + for _ in range(cutn_batches): + t_int = ( + int(t.item()) + 1 + ) # errors on last step without +1, need to find source + # when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution' + try: + input_resolution = self.vision_model.input_resolution + except: + input_resolution = 224 + + cuts = MakeCutoutsDango( + input_resolution, + Overview=cut_overview[1000 - t_int], + InnerCrop=cut_innercut[1000 - t_int], + IC_Size_Pow=cut_ic_pow, + IC_Grey_P=cut_icgray_p[1000 - t_int], + ) + clip_in = normalize( + cuts( + x_in.add(paddle.to_tensor(1.0)).divide( + paddle.to_tensor(2.0)))) + image_embeds = self.get_image_features(clip_in) + + dists = spherical_dist_loss( + image_embeds.unsqueeze(1), + target_text_embeds.unsqueeze(0), + ) + + dists = dists.reshape([ + cut_overview[1000 - t_int] + cut_innercut[1000 - t_int], + n, + -1, + ]) + losses = dists.sum(2).mean(0) + x_in_grad += ( + paddle.grad(losses.sum() * clip_guidance_scale, x_in)[0] / + cutn_batches) + tv_losses = tv_loss(x_in) + range_losses = range_loss(x_in) + sat_losses = paddle.abs(x_in - x_in.clip(min=-1, max=1)).mean() + loss = (tv_losses.sum() * tv_scale + + range_losses.sum() * range_scale + + sat_losses.sum() * sat_scale) + if init is not None and init_scale: + init_losses = lpips_model(x_in, init) + loss = loss + init_losses.sum() * init_scale + x_in_grad += paddle.grad(loss, x_in)[0] + if not paddle.isnan(x_in_grad).any(): + grad = -paddle.grad(x_in_d, x, x_in_grad)[0] + else: + x_is_NaN = True + grad = paddle.zeros_like(x) + if clamp_grad and not x_is_NaN: + magnitude = grad.square().mean().sqrt() + return (grad * magnitude.clip(max=clamp_max) / magnitude) + return grad + + # we use ddim sample + sample_fn = self.diffusion.ddim_sample_loop_progressive + + da_batches = [] + + # process output file name + output_filename_list = ["progress"] + if batch_name != "": + output_filename_list.append(batch_name) + if seed is not None: + output_filename_list.append(str(seed)) + output_filename_prefix = "-".join(output_filename_list) + + for _nb in range(n_batches): + gc.collect() + paddle.device.cuda.empty_cache() + cur_t = self.diffusion.num_timesteps - skip_steps - 1 + + if perlin_init: + init = regen_perlin(perlin_mode, side_y, side_x, batch_size) + + samples = sample_fn( + self.unet_model, + (batch_size, 3, side_y, side_x), + clip_denoised=clip_denoised, + model_kwargs={}, + cond_fn=cond_fn, + progress=True, + skip_timesteps=skip_steps, + init_image=init, + randomize_class=randomize_class, + eta=eta, + ) + + for j, sample in enumerate(samples): + cur_t -= 1 + if j % save_rate == 0 or cur_t == -1: + for b, image in enumerate(sample['pred_xstart']): + image = (((image + 1) / 2).clip( + 0, 1).squeeze().transpose([1, 2, 0]).numpy() * + 255).astype("uint8") + image = Image.fromarray(image) + image.save(output_dir / + f'{output_filename_prefix}-{_nb}-{j}.png') + if cur_t == -1: + da_batches.append(image) + + return da_batches diff --git a/paddlenlp/transformers/image_utils.py b/paddlenlp/transformers/image_utils.py new file mode 100644 index 000000000000..4eb66338ef51 --- /dev/null +++ 
b/paddlenlp/transformers/image_utils.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Union +import paddle +import numpy as np +from PIL import Image +import PIL.Image +import PIL.ImageOps + +import requests + +IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] +IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] +IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5] +IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5] + +ImageInput = Union[PIL.Image.Image, np.ndarray, "paddle.Tensor", + List[PIL.Image.Image], List[np.ndarray], + List["paddle.Tensor"] # noqa + ] + + +def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image": + """ + Loads `image` to a PIL Image. + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + Returns: + `PIL.Image.Image`: A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + # We need to actually check for a real protocol, otherwise it's impossible to use a local file + # like http_huggingface_co.png + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +class ImageFeatureExtractionMixin: + """ + Mixin that contain utilities for preparing image features. + """ + + def _ensure_format_supported(self, image): + if not isinstance( + image, + (PIL.Image.Image, np.ndarray)) and not paddle.is_tensor(image): + raise ValueError( + f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and " + "`paddle.Tensor` are.") + + def to_pil_image(self, image, rescale=None): + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `paddle.Tensor`): + The image to convert to the PIL Image format. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will + default to `True` if the image type is a floating type, `False` otherwise. + """ + self._ensure_format_supported(image) + + if paddle.is_tensor(image): + image = image.numpy() + + if isinstance(image, np.ndarray): + if rescale is None: + # rescale default to the array being of floating type. + rescale = isinstance(image.flat[0], np.floating) + # If the channel as been moved to first dim, we put it back at the end. 
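+            # e.g. a float array of shape (3, 224, 224) is transposed to
+            # (224, 224, 3), scaled by 255 (rescale defaults to True for float
+            # arrays) and cast to uint8 before PIL.Image.fromarray is called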
+ if image.ndim == 3 and image.shape[0] in [1, 3]: + image = image.transpose([1, 2, 0]) + if rescale: + image = image * 255 + image = image.astype(np.uint8) + return PIL.Image.fromarray(image) + return image + + def convert_rgb(self, image): + """ + Converts `PIL.Image.Image` to RGB format. + Args: + image (`PIL.Image.Image`): + The image to convert. + """ + self._ensure_format_supported(image) + if not isinstance(image, PIL.Image.Image): + return image + + return image.convert("RGB") + + def to_numpy_array(self, image, rescale=None, channel_first=True): + """ + Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first + dimension. + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`): + The image to convert to a NumPy array. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will + default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise. + channel_first (`bool`, *optional*, defaults to `True`): + Whether or not to permute the dimensions of the image to put the channel dimension first. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = np.array(image) + + if paddle.is_tensor(image): + image = image.numpy() + + if rescale is None: + rescale = isinstance(image.flat[0], np.integer) + + if rescale: + image = image.astype(np.float32) / 255.0 + + if channel_first and image.ndim == 3: + image = image.transpose([2, 0, 1]) + + return image + + def expand_dims(self, image): + """ + Expands 2-dimensional `image` to 3 dimensions. + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`): + The image to expand. + """ + self._ensure_format_supported(image) + + # Do nothing if PIL image + if isinstance(image, PIL.Image.Image): + return image + + if paddle.is_tensor(image): + image = image.unsqueeze(0) + else: + image = np.expand_dims(image, axis=0) + return image + + def normalize(self, image, mean, std): + """ + Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array + if it's a PIL Image. + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`): + The image to normalize. + mean (`List[float]` or `np.ndarray` or `paddle.Tensor`): + The mean (per channel) to use for normalization. + std (`List[float]` or `np.ndarray` or `paddle.Tensor`): + The standard deviation (per channel) to use for normalization. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = self.to_numpy_array(image) + + if isinstance(image, np.ndarray): + if not isinstance(mean, np.ndarray): + mean = np.array(mean).astype(image.dtype) + if not isinstance(std, np.ndarray): + std = np.array(std).astype(image.dtype) + elif paddle.is_tensor(image): + import paddle + + if not isinstance(mean, paddle.Tensor): + mean = paddle.to_tensor(mean).astype(image.dtype) + if not isinstance(std, paddle.Tensor): + std = paddle.to_tensor(std).astype(image.dtype) + + if image.ndim == 3 and image.shape[0] in [1, 3]: + return (image - mean[:, None, None]) / std[:, None, None] + else: + return (image - mean) / std + + def resize(self, + image, + size, + resample=Image.BILINEAR, + default_to_square=True, + max_size=None): + """ + Resizes `image`. Enforces conversion of input to PIL.Image. + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`): + The image to resize. 
+ size (`int` or `Tuple[int, int]`): + The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be + matched to this. + If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If + `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to + this number. i.e, if height > width, then image will be rescaled to (size * height / width, size). + resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): + The filter to user for resampling. + default_to_square (`bool`, *optional*, defaults to `True`): + How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a + square (`size`,`size`). If set to `False`, will replicate + [`paddle.vision.transforms.Resize`](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/vision/transforms/Resize_cn.html#resize) + with support for resizing only the smallest edge and providing an optional `max_size`. + max_size (`int`, *optional*, defaults to `None`): + The maximum allowed for the longer edge of the resized image: if the longer edge of the image is + greater than `max_size` after being resized according to `size`, then the image is resized again so + that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller + edge may be shorter than `size`. Only used if `default_to_square` is `False`. + Returns: + image: A resized `PIL.Image.Image`. + """ + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + if isinstance(size, list): + size = tuple(size) + + if isinstance(size, int) or len(size) == 1: + if default_to_square: + size = (size, size) if isinstance(size, int) else (size[0], + size[0]) + else: + width, height = image.size + # specified size only for the smallest edge + short, long = (width, height) if width <= height else (height, + width) + requested_new_short = size if isinstance(size, int) else size[0] + + if short == requested_new_short: + return image + + new_short, new_long = requested_new_short, int( + requested_new_short * long / short) + + if max_size is not None: + if max_size <= requested_new_short: + raise ValueError( + f"max_size = {max_size} must be strictly greater than the requested " + f"size for the smaller edge size = {size}") + if new_long > max_size: + new_short, new_long = int(max_size * new_short / + new_long), max_size + + size = (new_short, new_long) if width <= height else (new_long, + new_short) + + return image.resize(size, resample=resample) + + def center_crop(self, image, size): + """ + Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the + size given, it will be padded (so the returned result has the size asked). + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)): + The image to resize. + size (`int` or `Tuple[int, int]`): + The size to which crop the image. + Returns: + new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `paddle.Tensor` of shape: (n_channels, + height, width). 
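+
+        Illustrative example (uses the ``PIL``/``np`` imports of this module; the
+        input sizes are arbitrary)::
+
+            mixin = ImageFeatureExtractionMixin()
+            img = PIL.Image.new("RGB", (256, 256))
+            crop = mixin.center_crop(img, 224)  # a 224x224 PIL image
+            arr = mixin.center_crop(np.zeros((3, 256, 256)), 224)  # shape (3, 224, 224)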
+ """ + self._ensure_format_supported(image) + + if not isinstance(size, tuple): + size = (size, size) + + # PIL Image.size is (width, height) but NumPy array and paddle Tensors have (height, width) + if paddle.is_tensor(image) or isinstance(image, np.ndarray): + if image.ndim == 2: + image = self.expand_dims(image) + image_shape = image.shape[1:] if image.shape[0] in [ + 1, 3 + ] else image.shape[:2] + else: + image_shape = (image.size[1], image.size[0]) + + top = (image_shape[0] - size[0]) // 2 + bottom = top + size[ + 0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result. + left = (image_shape[1] - size[1]) // 2 + right = left + size[ + 1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result. + + # For PIL Images we have a method to crop directly. + if isinstance(image, PIL.Image.Image): + return image.crop((left, top, right, bottom)) + + # Check if image is in (n_channels, height, width) or (height, width, n_channels) format + channel_first = True if image.shape[0] in [1, 3] else False + + # Transpose (height, width, n_channels) format images + if not channel_first: + if isinstance(image, np.ndarray): + image = image.transpose([2, 0, 1]) + if paddle.is_tensor(image): + image = image.transpose([2, 0, 1]) + + # Check if cropped area is within image boundaries + if top >= 0 and bottom <= image_shape[ + 0] and left >= 0 and right <= image_shape[1]: + return image[..., top:bottom, left:right] + + # Otherwise, we may need to pad if the image is too small. Oh joy... + new_shape = image.shape[:-2] + (max( + size[0], image_shape[0]), max(size[1], image_shape[1])) + if isinstance(image, np.ndarray): + new_image = np.zeros_like(image, shape=new_shape) + elif paddle.is_tensor(image): + new_image = paddle.zeros(new_shape, dtype=image.dtype) + + top_pad = (new_shape[-2] - image_shape[0]) // 2 + bottom_pad = top_pad + image_shape[0] + left_pad = (new_shape[-1] - image_shape[1]) // 2 + right_pad = left_pad + image_shape[1] + new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image + + top += top_pad + bottom += top_pad + left += left_pad + right += left_pad + + new_image = new_image[..., + max(0, top):min(new_image.shape[-2], bottom), + max(0, left):min(new_image.shape[-1], right)] + + return new_image + + def flip_channel_order(self, image): + """ + Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of + `image` to a NumPy array if it's a PIL Image. + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`): + The image whose color channels to flip. If `np.ndarray` or `paddle.Tensor`, the channel dimension should + be first. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = self.to_numpy_array(image) + + return image[::-1, :, :] + + def rotate(self, + image, + angle, + resample=Image.NEAREST, + expand=0, + center=None, + translate=None, + fillcolor=None): + """ + Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees + counter clockwise around its centre. + Args: + image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`): + The image to rotate. If `np.ndarray` or `paddle.Tensor`, will be converted to `PIL.Image.Image` before + rotating. + Returns: + image: A rotated `PIL.Image.Image`. 
+ """ + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + return image.rotate(angle, + resample=resample, + expand=expand, + center=center, + translate=translate, + fillcolor=fillcolor) diff --git a/paddlenlp/transformers/stable_diffusion_utils/__init__.py b/paddlenlp/transformers/stable_diffusion_utils/__init__.py new file mode 100644 index 000000000000..e03311d4453f --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/__init__.py @@ -0,0 +1,5 @@ +from .unet_2d_condition import UNet2DConditionModel +from .vae import AutoencoderKL +from .schedulers import (LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, + SchedulerMixin) +from .utils import StableDiffusionMixin diff --git a/paddlenlp/transformers/stable_diffusion_utils/attention.py b/paddlenlp/transformers/stable_diffusion_utils/attention.py new file mode 100644 index 000000000000..caad51663b36 --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/attention.py @@ -0,0 +1,299 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + + +def finfo(dtype): + if dtype == paddle.float32: + return np.finfo(np.float32) + if dtype == paddle.float16: + return np.finfo(np.float16) + if dtype == paddle.float64: + return np.finfo(np.float64) + + +class AttentionBlock(nn.Layer): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ Uses three q, k, v linear layers to compute attention + """ + + def __init__( + self, + channels, + num_head_channels=None, + num_groups=32, + rescale_output_factor=1.0, + eps=1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = (channels // num_head_channels + if num_head_channels is not None else 1) + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, + num_groups=num_groups, + epsilon=eps) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels) + + def transpose_for_scores(self, projection: paddle.Tensor) -> paddle.Tensor: + new_projection_shape = projection.shape[:-1] + [self.num_heads, -1] + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.reshape(new_projection_shape).transpose( + [0, 2, 1, 3]) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.reshape([batch, channel, height * width + ]).transpose([0, 2, 1]) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attention_scores = paddle.matmul(query_states * scale, + key_states * scale, + transpose_y=True) + attention_probs = F.softmax(attention_scores.astype("float32"), + axis=-1).astype(attention_scores.dtype) + + # compute attention output + context_states = paddle.matmul(attention_probs, value_states) + + context_states = context_states.transpose([0, 2, 1, 3]) + new_context_states_shape = context_states.shape[:-2] + [ + self.channels, + ] + context_states = context_states.reshape(new_context_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(context_states) + hidden_states = hidden_states.transpose([0, 2, 1]).reshape( + [batch, channel, height, width]) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class SpatialTransformer(nn.Layer): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. 
Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None): + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = nn.GroupNorm(num_groups=32, + num_channels=in_channels, + epsilon=1e-6) + + self.proj_in = nn.Conv2D(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.LayerList([ + BasicTransformerBlock(inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim) for d in range(depth) + ]) + + self.proj_out = nn.Conv2D(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = x.transpose([0, 2, 3, 1]).reshape([b, h * w, c]) + for block in self.transformer_blocks: + x = block(x, context=context) + x = x.reshape([b, h, w, c]).transpose([0, 3, 1, 2]) + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock(nn.Layer): + + def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class CrossAttention(nn.Layer): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias_attr=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias_attr=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape( + [batch_size, seq_len, head_size, dim // head_size]) + tensor = tensor.transpose([0, 2, 1, 3]).reshape( + [batch_size * head_size, seq_len, dim // head_size]) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape( + [batch_size // head_size, head_size, seq_len, dim]) + tensor = tensor.transpose([0, 2, 1, 3]).reshape( + [batch_size // head_size, seq_len, dim * head_size]) + return tensor + + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + + h = self.heads + + q = self.to_q(x) + context = context if context is not None else x + k = self.to_k(context) + v = self.to_v(context) + + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + # sim = paddle.einsum("b i d, b j d -> b i j", q, k) * self.scale 
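+ # Scaling the queries before the matmul is mathematically equivalent to scaling the similarity
+ # matrix afterwards (the commented-out line above), but keeps the intermediate values smaller,
+ # which can help numerical stability in float16.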
+ sim = paddle.einsum("b i d, b j d -> b i j", q * self.scale, k) + + if mask is not None: + mask = mask.reshape([batch_size, -1]) + max_neg_value = -finfo(sim.dtype).max + mask = mask[:, None, :].expand([h, -1, -1]).astype(sim.dtype) + sim = sim * mask + (1 - mask) * max_neg_value + + # attention, what we cannot get enough of + attn = F.softmax(sim, axis=-1) + + out = paddle.einsum("b i j, b j d -> b i d", attn, v) + out = self.reshape_batch_dim_to_heads(out) + return self.to_out(out) + + +class FeedForward(nn.Layer): + + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +# feedforward +class GEGLU(nn.Layer): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, axis=-1) + return x * F.gelu(gate) diff --git a/paddlenlp/transformers/stable_diffusion_utils/embeddings.py b/paddlenlp/transformers/stable_diffusion_utils/embeddings.py new file mode 100644 index 000000000000..f2767cb0306e --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/embeddings.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import paddle +import paddle.nn as nn + + +def get_timestep_embedding( + timesteps, + embedding_dim, + flip_sin_to_cos=False, + downscale_freq_shift=1, + scale=1, + max_period=10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
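+ Illustrative example: with `embedding_dim=4` (so `half_dim=2`) and the default arguments, the embedding of a
+ timestep `t` is `[sin(t * w_0), sin(t * w_1), cos(t * w_0), cos(t * w_1)]`, where
+ `w_i = exp(-log(max_period) * i / (half_dim - downscale_freq_shift))`.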
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * paddle.arange( + start=0, end=half_dim, dtype="float32") + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = paddle.exp(exponent) + emb = timesteps[:, None].astype("float32") * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = paddle.concat(emb, paddle.zeros([emb.shape[0], 1]), axis=-1) + return emb + + +class TimestepEmbedding(nn.Layer): + + def __init__(self, channel, time_embed_dim, act_fn="silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.Silu() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Layer): + + def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class GaussianFourierProjection(nn.Layer): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.register_buffer("weight", paddle.randn((embedding_size, )) * scale) + + # to delete later + self.register_buffer("W", paddle.randn((embedding_size, )) * scale) + + self.weight = self.W + + def forward(self, x): + x = paddle.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + out = paddle.concat([paddle.sin(x_proj), paddle.cos(x_proj)], axis=-1) + return out diff --git a/paddlenlp/transformers/stable_diffusion_utils/resnet.py b/paddlenlp/transformers/stable_diffusion_utils/resnet.py new file mode 100644 index 000000000000..df1e1bb5f7df --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/resnet.py @@ -0,0 +1,611 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def pad_new(x, pad, mode="constant", value=0): + new_pad = [] + for _ in range(x.ndim * 2 - len(pad)): + new_pad.append(0) + ndim = list(range(x.ndim - 1, 0, -1)) + axes_start = {} + for i, _pad in enumerate(pad): + if _pad < 0: + new_pad.append(0) + zhengshu, yushu = divmod(i, 2) + if yushu == 0: + axes_start[ndim[zhengshu]] = -_pad + else: + new_pad.append(_pad) + + padded = paddle.nn.functional.pad(x, new_pad, mode=mode, value=value) + padded_shape = paddle.shape(padded) + axes = [] + starts = [] + ends = [] + for k, v in axes_start.items(): + axes.append(k) + starts.append(v) + ends.append(padded_shape[k]) + assert v < padded_shape[k] + + if axes: + return padded.slice(axes=axes, starts=starts, ends=ends) + else: + return padded + + +class Upsample2D(nn.Layer): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.Conv2DTranspose(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2D(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + x = self.conv(x) + else: + x = self.Conv2d_0(x) + + return x + + +class Downsample2D(nn.Layer): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
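+ For example (illustrative), `Downsample2D(64, use_conv=True)` maps a `[N, 64, H, W]` input to
+ `[N, 64, H // 2, W // 2]` (for even H and W) with a stride-2 3x3 convolution; with `use_conv=False` a
+ 2x2 average pooling with stride 2 is used instead and `out_channels` must equal `channels`.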
+ """ + + def __init__(self, + channels, + use_conv=False, + out_channels=None, + padding=1, + name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2D(self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2D(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + x = pad_new(x, pad, mode="constant", value=0) + + assert x.shape[1] == self.channels + x = self.conv(x) + + return x + + +class FirUpsample2D(nn.Layer): + + def __init__(self, + channels=None, + out_channels=None, + use_conv=False, + fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2D(channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as + `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if k is None: + k = [1] * factor + + # setup kernel + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + + k = k * (gain * (factor**2)) + + if self.use_conv: + convH = w.shape[2] + convW = w.shape[3] + inC = w.shape[1] + + p = (k.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + stride = [1, 1, factor, factor] + output_shape = ( + (x.shape[2] - 1) * factor + convH, + (x.shape[3] - 1) * factor + convW, + ) + output_padding = ( + output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + inC = w.shape[1] + num_groups = x.shape[1] // inC + + # Transpose weights. 
+ w = paddle.reshape(w, (num_groups, -1, inC, convH, convW)) + w = w[..., ::-1, ::-1].transpose([0, 2, 1, 3, 4]) + w = paddle.reshape(w, (num_groups * inC, -1, convH, convW)) + + x = F.conv2d_transpose(x, + w, + stride=stride, + output_padding=output_padding, + padding=0) + + x = upfirdn2d_native(x, + paddle.to_tensor(k), + pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + else: + p = k.shape[0] - factor + x = upfirdn2d_native( + x, + paddle.to_tensor(k), + up=factor, + pad=((p + 1) // 2 + factor - 1, p // 2), + ) + + return x + + def forward(self, x): + if self.use_conv: + h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) + h = h + self.Conv2d_0.bias.reshape([1, -1, 1, 1]) + else: + h = self._upsample_2d(x, k=self.fir_kernel, factor=2) + + return h + + +class FirDownsample2D(nn.Layer): + + def __init__(self, + channels=None, + out_channels=None, + use_conv=False, + fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2D(channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, + filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // + numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: + Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. 
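+ In short (illustrative): with the default `fir_kernel=(1, 3, 3, 1)` and `factor=2`, the input is low-pass
+ filtered with the normalized separable FIR kernel and then subsampled by 2 along each spatial dimension.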
+ """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + + # setup kernel + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + + k = k * gain + + if self.use_conv: + _, _, convH, convW = w.shape + p = (k.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d_native(x, + paddle.to_tensor(k), + pad=((p + 1) // 2, p // 2)) + x = F.conv2d(x, w, stride=s, padding=0) + else: + p = k.shape[0] - factor + x = upfirdn2d_native(x, + paddle.to_tensor(k), + down=factor, + pad=((p + 1) // 2, p // 2)) + + return x + + def forward(self, x): + if self.use_conv: + x = self._downsample_2d(x, + w=self.Conv2d_0.weight, + k=self.fir_kernel) + x = x + self.Conv2d_0.bias.reshape([1, -1, 1, 1]) + else: + x = self._downsample_2d(x, k=self.fir_kernel, factor=2) + + return x + + +class ResnetBlock2D(nn.Layer): + + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_nin_shortcut=None, + up=False, + down=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm(num_groups=groups, + num_channels=in_channels, + epsilon=eps) + + self.conv1 = nn.Conv2D(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + if temb_channels is not None: + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = nn.GroupNorm(num_groups=groups_out, + num_channels=out_channels, + epsilon=eps) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2D(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.Silu() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, k=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, + scale_factor=2.0, + mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, k=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, + use_conv=False, + padding=1, + name="op") + + self.use_nin_shortcut = (self.in_channels != self.out_channels if + use_nin_shortcut is None else use_nin_shortcut) + + self.conv_shortcut = None + if self.use_nin_shortcut: + self.conv_shortcut = nn.Conv2D(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb, hey=False): + h = x + + # make sure hidden states is in float32 + # when running in half-precision + h = self.norm1(h.astype("float32")).astype(h.dtype) + h = 
self.nonlinearity(h) + + if self.upsample is not None: + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + h = h + temb + + # make sure hidden states is in float32 + # when running in half-precision + h = self.norm2(h.astype("float32")).astype(h.dtype) + h = self.nonlinearity(h) + + h = self.dropout(h) + h = self.conv2(h) + + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) + + out = (x + h) / self.output_scale_factor + + return out + + +class Mish(nn.Layer): + + def forward(self, x): + return x * F.tanh(F.softplus(x)) + + +def upsample_2d(x, k=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: + multiple of the upsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + + k = k * (gain * (factor**2)) + p = k.shape[0] - factor + return upfirdn2d_native(x, + paddle.to_tensor(k), + up=factor, + pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_2d(x, k=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + + k = k * gain + p = k.shape[0] - factor + return upfirdn2d_native(x, + paddle.to_tensor(k), + down=factor, + pad=((p + 1) // 2, p // 2)) + + +def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = input.shape + input = input.reshape([-1, in_h, in_w, 1]) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.reshape([-1, in_h, 1, in_w, 1, minor]) + # TODO + out = pad_new(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.reshape([-1, in_h * up_y, in_w * up_x, minor]) + + out = pad_new( + out, + [0, 0, + max(pad_x0, 0), + max(pad_x1, 0), + max(pad_y0, 0), + max(pad_y1, 0)]) + out = out[:, + max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.transpose([0, 3, 1, 2]) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = paddle.flip(kernel, [0, 1]).reshape([1, 1, kernel_h, kernel_w]) + out = F.conv2d(out, w) + out = out.reshape([ + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ]) + out = out.transpose([0, 2, 3, 1]) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.reshape([-1, channel, out_h, out_w]) diff --git a/paddlenlp/transformers/stable_diffusion_utils/schedulers.py b/paddlenlp/transformers/stable_diffusion_utils/schedulers.py new file mode 100644 index 000000000000..9e419c8cc807 --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/schedulers.py @@ -0,0 +1,703 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union +import math +import numpy as np +import paddle +from scipy import integrate + +__all__ = [ + "SchedulerMixin", + "DDIMScheduler", + "LMSDiscreteScheduler", + "PNDMScheduler", +] + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t + from 0 to 1 and + produces the cumulative product of (1-beta) up to that part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
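+ In this implementation `alpha_bar` is fixed to the squared-cosine schedule
+ `cos((t + 0.008) / 1.008 * pi / 2) ** 2`, and each beta is computed as
+ `min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)` for consecutive timestep fractions `t1 = i / N` and
+ `t2 = (i + 1) / N`.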
+ """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2)**2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + +class SchedulerMixin: + + def set_format(self, tensor_format="pd"): + self.tensor_format = tensor_format + if tensor_format == "pd": + for key, value in vars(self).items(): + if isinstance(value, np.ndarray): + setattr(self, key, paddle.to_tensor(value)) + + return self + + def clip(self, tensor, min_value=None, max_value=None): + tensor_format = getattr(self, "tensor_format", "pd") + + if tensor_format == "np": + return np.clip(tensor, min_value, max_value) + elif tensor_format == "pd": + return paddle.clip(tensor, min_value, max_value) + + raise ValueError( + f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def log(self, tensor): + tensor_format = getattr(self, "tensor_format", "pd") + + if tensor_format == "np": + return np.log(tensor) + elif tensor_format == "pd": + return paddle.log(tensor) + + raise ValueError( + f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def match_shape( + self, + values: Union[np.ndarray, paddle.Tensor], + broadcast_array: Union[np.ndarray, paddle.Tensor], + ): + """ + Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. + + Args: + values: an array or tensor of values to extract. + broadcast_array: an array with a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + Returns: + a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + + tensor_format = getattr(self, "tensor_format", "pd") + values = values.flatten() + + while len(values.shape) < len(broadcast_array.shape): + values = values[..., None] + + return values + + def norm(self, tensor): + tensor_format = getattr(self, "tensor_format", "pd") + if tensor_format == "np": + return np.linalg.norm(tensor) + elif tensor_format == "pd": + return paddle.norm(tensor.reshape([tensor.shape[0], -1]), + axis=-1).mean() + + raise ValueError( + f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def randn_like(self, tensor): + tensor_format = getattr(self, "tensor_format", "pd") + if tensor_format == "np": + return np.random.randn(np.shape(tensor)) + elif tensor_format == "pd": + return paddle.randn(tensor.shape) + + raise ValueError( + f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def zeros_like(self, tensor): + tensor_format = getattr(self, "tensor_format", "pd") + if tensor_format == "np": + return np.zeros_like(tensor) + elif tensor_format == "pd": + return paddle.zeros_like(tensor) + + raise ValueError( + f"`self.tensor_format`: {self.tensor_format} is not valid.") + + +class LMSDiscreteScheduler(SchedulerMixin): + + def __init__( + self, + num_train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + trained_betas=None, + timestep_values=None, + tensor_format="pd", + ): + """ + Linear Multistep Scheduler for discrete beta schedules. 
Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + """ + self.num_train_timesteps = num_train_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + self.beta_schedule = beta_schedule + self.trained_betas = trained_betas + self.timestep_values = timestep_values + self.tensor_format = tensor_format + + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, + beta_end, + num_train_timesteps, + dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = (np.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=np.float32, + )**2) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod)**0.5 + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.derivatives = [] + self.set_format(tensor_format=tensor_format) + + def get_lms_coefficient(self, order, t, current_order): + """ + Compute a linear multistep coefficient + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - self.sigmas[t - k]) / ( + self.sigmas[t - current_order] - self.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, + self.sigmas[t], + self.sigmas[t + 1], + epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, num_inference_steps): + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array( + ((1 - self.alphas_cumprod) / self.alphas_cumprod)**0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]) + + self.derivatives = [] + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[paddle.Tensor, np.ndarray], + timestep: int, + sample: Union[paddle.Tensor, np.ndarray], + order: int = 4, + ): + sigma = self.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + if len(self.derivatives) > order: + self.derivatives.pop(0) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [ + self.get_lms_coefficient(order, timestep, curr_order) + for curr_order in range(order) + ] + + # 4. 
Compute previous sample based on the derivatives path + prev_sample = sample + sum(coeff * derivative + for coeff, derivative in zip( + lms_coeffs, reversed(self.derivatives))) + + return {"prev_sample": prev_sample} + + def add_noise(self, original_samples, noise, timesteps): + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + return noisy_samples + + def __len__(self): + return self.num_train_timesteps + + +class PNDMScheduler(SchedulerMixin): + + def __init__( + self, + num_train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + skip_prk_steps=False, + tensor_format="pd", + ): + self.num_train_timesteps = num_train_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + self.beta_schedule = beta_schedule + self.skip_prk_steps = skip_prk_steps + self.tensor_format = tensor_format + + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, + beta_end, + num_train_timesteps, + dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = (np.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=np.float32, + )**2) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.one = np.array(1.0) + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.cur_model_output = 0 + self.counter = 0 + self.cur_sample = None + self.ets = [] + + # setable values + self.num_inference_steps = None + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self._offset = 0 + self.prk_timesteps = None + self.plms_timesteps = None + self.timesteps = None + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps, offset=0): + self.num_inference_steps = num_inference_steps + self._timesteps = list( + range( + 0, + self.num_train_timesteps, + self.num_train_timesteps // num_inference_steps, + )) + self._offset = offset + self._timesteps = np.array([t + self._offset for t in self._timesteps]) + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. 
When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = np.concatenate([ + self._timesteps[:-1], self._timesteps[-2:-1], + self._timesteps[-1:] + ])[::-1].copy() + else: + prk_timesteps = np.array( + self._timesteps[-self.pndm_order:]).repeat(2) + np.tile( + np.array([ + 0, self.num_train_timesteps // num_inference_steps // 2 + ]), + self.pndm_order, + ) + self.prk_timesteps = ( + prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][::-1].copy( + ) # we copy to avoid having negative strides which are not supported by paddle.to_tensor + + self.timesteps = np.concatenate( + [self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + + self.ets = [] + self.counter = 0 + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[paddle.Tensor, np.ndarray], + timestep: int, + sample: Union[paddle.Tensor, np.ndarray], + ): + if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: + return self.step_prk(model_output=model_output, + timestep=timestep, + sample=sample) + else: + return self.step_plms(model_output=model_output, + timestep=timestep, + sample=sample) + + def step_prk( + self, + model_output: Union[paddle.Tensor, np.ndarray], + timestep: int, + sample: Union[paddle.Tensor, np.ndarray], + ): + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + diff_to_prev = (0 if self.counter % 2 else self.num_train_timesteps // + self.num_inference_steps // 2) + prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output += 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = 0 + + # cur_sample should not be `None` + cur_sample = self.cur_sample if self.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, + model_output) + self.counter += 1 + + return {"prev_sample": prev_sample} + + def step_plms( + self, + model_output: Union[paddle.Tensor, np.ndarray], + timestep: int, + sample: Union[paddle.Tensor, np.ndarray], + ): + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. 
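+ Once at least four noise predictions have been stored in `self.ets`, the update uses the 4th-order
+ Adams-Bashforth combination `(55 * e_t - 59 * e_{t-1} + 37 * e_{t-2} - 9 * e_{t-3}) / 24`; earlier steps
+ fall back to lower-order combinations of the available predictions.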
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + if not self.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information.") + + prev_timestep = max( + timestep - self.num_train_timesteps // self.num_inference_steps, 0) + + if self.counter != 1: + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = (timestep + + self.num_train_timesteps // self.num_inference_steps) + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = None + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, + model_output) + self.counter += 1 + + return {"prev_sample": prev_sample} + + def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - + self._offset] + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t)**(0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev**(0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev)**(0.5) + + # full formula (9) + prev_sample = (sample_coeff * sample - + (alpha_prod_t_prev - alpha_prod_t) * model_output / + model_output_denom_coeff) + + return prev_sample + + def add_noise(self, original_samples, noise, timesteps): + sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, + original_samples) + + noisy_samples = (sqrt_alpha_prod * original_samples + + sqrt_one_minus_alpha_prod * noise) + return noisy_samples + + def __len__(self): + return self.num_train_timesteps + + +class DDIMScheduler( + SchedulerMixin, ): + + def __init__( + self, + num_train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + trained_betas=None, + timestep_values=None, + 
clip_sample=True, + set_alpha_to_one=True, + tensor_format="pd", + ): + self.num_train_timesteps = num_train_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + self.beta_schedule = beta_schedule + self.trained_betas = trained_betas + self.timestep_values = timestep_values + self.clip_sample = clip_sample + self.set_alpha_to_one = set_alpha_to_one + self.tensor_format = tensor_format + + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, + beta_end, + num_train_timesteps, + dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = (np.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=np.float32, + )**2) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this paratemer simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = (np.array(1.0) if set_alpha_to_one else + self.alphas_cumprod[0]) + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = (self.alphas_cumprod[prev_timestep] if + prev_timestep >= 0 else self.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / + beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps(self, num_inference_steps, offset=0): + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange( + 0, + self.num_train_timesteps, + self.num_train_timesteps // self.num_inference_steps, + )[::-1].copy() + self.timesteps += offset + self.set_format(tensor_format=self.tensor_format) + + def step(self, + model_output: Union[paddle.Tensor, np.ndarray], + timestep: int, + sample: Union[paddle.Tensor, np.ndarray], + eta: float = 0.0, + use_clipped_model_output: bool = False): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointingc to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = (timestep - + self.num_train_timesteps // self.num_inference_steps) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = (self.alphas_cumprod[prev_timestep] if + prev_timestep >= 0 else self.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t + + # 3. 
compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t** + (0.5) * model_output) / alpha_prod_t**(0.5) + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = self.clip(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance**(0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t** + (0.5) * pred_original_sample) / beta_prod_t**(0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - + std_dev_t**2)**(0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = (alpha_prod_t_prev**(0.5) * pred_original_sample + + pred_sample_direction) + + if eta > 0: + noise = paddle.randn(model_output.shape) + variance = (self._get_variance(timestep, prev_timestep)**(0.5) * + eta * noise) + + if not paddle.is_tensor(model_output): + variance = variance.numpy() + + prev_sample = prev_sample + variance + + return {"prev_sample": prev_sample} + + def add_noise(self, original_samples, noise, timesteps): + sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, + original_samples) + + noisy_samples = (sqrt_alpha_prod * original_samples + + sqrt_one_minus_alpha_prod * noise) + return noisy_samples + + def __len__(self): + return self.num_train_timesteps diff --git a/paddlenlp/transformers/stable_diffusion_utils/unet_2d_condition.py b/paddlenlp/transformers/stable_diffusion_utils/unet_2d_condition.py new file mode 100644 index 000000000000..86e3a75601d8 --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/unet_2d_condition.py @@ -0,0 +1,235 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Union + +import paddle +import paddle.nn as nn + +from .embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block + + +class UNet2DConditionModel(nn.Layer): + + def __init__( + self, + sample_size=64, + in_channels=4, + out_channels=4, + center_input_sample=False, + flip_sin_to_cos=True, + freq_shift=0, + down_block_types=( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + block_out_channels=(320, 640, 1280, 1280), + layers_per_block=2, + downsample_padding=1, + mid_block_scale_factor=1, + act_fn="silu", + norm_num_groups=32, + norm_eps=1e-5, + cross_attention_dim=768, + attention_head_dim=8, + ): + super().__init__() + self.in_channels = in_channels + self.sample_size = sample_size + self.center_input_sample = center_input_sample + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2D(in_channels, + block_out_channels[0], + kernel_size=3, + padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, + freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, + time_embed_dim) + + self.down_blocks = nn.LayerList([]) + self.mid_block = None + self.up_blocks = nn.LayerList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min( + i + 1, + len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + epsilon=norm_eps, + ) + self.conv_act 
= nn.Silu() + self.conv_out = nn.Conv2D(block_out_channels[0], + out_channels, + 3, + padding=1) + + def forward( + self, + sample: paddle.Tensor, + timestep: Union[paddle.Tensor, float, int], + encoder_hidden_states: paddle.Tensor, + ) -> Dict[str, paddle.Tensor]: + + # 0. center input if necessary + if self.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not paddle.is_tensor(timesteps): + timesteps = paddle.to_tensor([timesteps], dtype="int64") + elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension + timesteps = paddle.broadcast_to(timesteps, [sample.shape[0]]) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample, ) + for downsample_block in self.down_blocks: + + if (hasattr(downsample_block, "attentions") + and downsample_block.attentions is not None): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, + temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, + emb, + encoder_hidden_states=encoder_hidden_states) + + # 5. up + for upsample_block in self.up_blocks: + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[:-len(upsample_block + .resnets)] + + if (hasattr(upsample_block, "attentions") + and upsample_block.attentions is not None): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample = upsample_block(hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample.astype("float32")).astype( + sample.dtype) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + output = {"sample": sample} + + return output diff --git a/paddlenlp/transformers/stable_diffusion_utils/unet_blocks.py b/paddlenlp/transformers/stable_diffusion_utils/unet_blocks.py new file mode 100644 index 000000000000..e29b040b01e6 --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/unet_blocks.py @@ -0,0 +1,1513 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
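Before the block definitions that follow, here is a minimal smoke test for the UNet2DConditionModel defined above. The import path is an assumption; the shapes follow the default config in the diff (4 latent channels, sample_size 64, cross_attention_dim 768).

import paddle
# assumed import path for the model added in this PR
from paddlenlp.transformers.stable_diffusion_utils.unet_2d_condition import UNet2DConditionModel

unet = UNet2DConditionModel()  # default Stable Diffusion configuration
latents = paddle.randn([1, 4, 64, 64])            # (batch, in_channels, H/8, W/8)
timestep = paddle.to_tensor([10], dtype="int64")
context = paddle.randn([1, 77, 768])              # e.g. CLIP text hidden states
noise_pred = unet(latents, timestep, encoder_hidden_states=context)["sample"]
assert noise_pred.shape == latents.shape          # the UNet predicts noise of the same shape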
+ +import numpy as np +import paddle +import paddle.nn as nn + +from .attention import AttentionBlock, SpatialTransformer +from .resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + ResnetBlock2D, + Upsample2D, +) + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, + downsample_padding=None, +): + down_block_type = (down_block_type[7:] + if down_block_type.startswith("UNetRes") else + down_block_type) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + cross_attention_dim=None, +): + up_block_type = (up_block_type[7:] + if up_block_type.startswith("UNetRes") else up_block_type) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for 
CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + resnet_groups = (resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32)) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + )) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + def forward(self, hidden_states, temb=None, encoder_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.attention_type == 
"default": + hidden_states = attn(hidden_states) + else: + hidden_states = attn(hidden_states, encoder_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Layer): + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + resnet_groups = (resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32)) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + SpatialTransformer( + in_channels, + attn_num_head_channels, + in_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + )) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + )) + + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + if add_downsample: + self.downsamplers = nn.LayerList([ + Downsample2D( + in_channels, + use_conv=True, + out_channels=out_channels, + 
padding=downsample_padding, + name="op", + ) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states, ) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + )) + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + if add_downsample: + self.downsamplers = nn.LayerList([ + Downsample2D( + in_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states, ) + + return hidden_states, output_states + + +class DownBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.LayerList(resnets) + + if add_downsample: + self.downsamplers = nn.LayerList([ + Downsample2D( + 
in_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states, ) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.LayerList(resnets) + + if add_downsample: + self.downsamplers = nn.LayerList([ + Downsample2D( + in_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + )) + + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + if add_downsample: + self.downsamplers = nn.LayerList([ + Downsample2D( + in_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for 
downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=np.sqrt(2.0), + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + self.attentions = nn.LayerList([]) + self.resnets = nn.LayerList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + )) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.LayerList( + [FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2D(3, + out_channels, + kernel_size=(1, 1), + stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states, ) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.LayerList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + 
pre_norm=resnet_pre_norm, + )) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.LayerList( + [FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2D(3, + out_channels, + kernel_size=(1, 1), + stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states, ) + + return hidden_states, output_states, skip_sample + + +class AttnUpBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_type="default", + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - + 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + )) + + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + if add_upsample: + self.upsamplers = nn.LayerList([ + Upsample2D(out_channels, + use_conv=True, + out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat([hidden_states, res_hidden_states], + axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + 
num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - + 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + )) + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + if add_upsample: + self.upsamplers = nn.LayerList([ + Upsample2D(out_channels, + use_conv=True, + out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat([hidden_states, res_hidden_states], + axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - + 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.LayerList(resnets) + + if add_upsample: + self.upsamplers = nn.LayerList([ + Upsample2D(out_channels, + use_conv=True, + out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet in self.resnets: + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat([hidden_states, 
res_hidden_states], + axis=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.LayerList(resnets) + + if add_upsample: + self.upsamplers = nn.LayerList([ + Upsample2D(out_channels, + use_conv=True, + out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + )) + + self.attentions = nn.LayerList(attentions) + self.resnets = nn.LayerList(resnets) + + if add_upsample: + self.upsamplers = nn.LayerList([ + Upsample2D(out_channels, + use_conv=True, + out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + 
attention_type="default", + output_scale_factor=np.sqrt(2.0), + upsample_padding=1, + add_upsample=True, + ): + super().__init__() + self.attentions = nn.LayerList([]) + self.resnets = nn.LayerList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - + 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + )) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2D(out_channels, + 3, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.skip_norm = nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True, + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat([hidden_states, res_hidden_states], + axis=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Layer): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.LayerList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - + 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + 
in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, + 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_nin_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2D(out_channels, + 3, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.skip_norm = nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True, + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat([hidden_states, res_hidden_states], + axis=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample diff --git a/paddlenlp/transformers/stable_diffusion_utils/utils.py b/paddlenlp/transformers/stable_diffusion_utils/utils.py new file mode 100644 index 000000000000..a11482596a03 --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/utils.py @@ -0,0 +1,486 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
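To illustrate the block factory defined above, the hypothetical snippet below builds one cross-attention down block the same way UNet2DConditionModel does and shows the (sample, skip-connection) pair its forward returns. The import path and the concrete channel sizes are assumptions taken from the UNet's default config, not values fixed by this diff.

import paddle
# assumed import path for the UNet blocks added in this PR
from paddlenlp.transformers.stable_diffusion_utils.unet_blocks import get_down_block

down = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,          # block_out_channels[0] * 4 in the UNet
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    cross_attention_dim=768,
    downsample_padding=1,
)

sample = paddle.randn([1, 320, 64, 64])
temb = paddle.randn([1, 1280])        # time embedding
context = paddle.randn([1, 77, 768])  # text conditioning
sample, res_samples = down(hidden_states=sample,
                           temb=temb,
                           encoder_hidden_states=context)
# res_samples holds one tensor per resnet plus the downsampled output; the UNet
# collects these and later feeds them to the matching up blocks as skip connections.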
+ +import inspect +import random + +import numpy as np +import paddle +from PIL import Image +from tqdm.auto import tqdm +from typing import Optional + +from .schedulers import PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler +from ..image_utils import load_image + +__all__ = ["StableDiffusionMixin"] + + +class StableDiffusionMixin: + + def set_scheduler(self, scheduler): + if isinstance(scheduler, + (PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler)): + self.scheduler = scheduler + elif isinstance(scheduler, str): + if scheduler == "pndm": + self.scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + skip_prk_steps=True, + ) + elif scheduler == "ddim": + self.scheduler = DDIMScheduler(beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False) + elif scheduler == "k-lms": + self.scheduler = LMSDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear") + else: + raise ValueError( + 'scheduler must be in ["pndm", "ddim", "k-lms"].') + else: + raise ValueError('scheduler error.') + + @classmethod + def preprocess_image(cls, image): + image = load_image(image) + w, h = image.size + w, h = map(lambda x: x - x % 32, + (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose([0, 3, 1, 2]) + image = paddle.to_tensor(image) + return 2.0 * image - 1.0 + + @classmethod + def preprocess_mask(cls, mask): + mask = load_image(mask) + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, + (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose([0, 1, 2, 3]) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = paddle.to_tensor(mask) + return mask + + @paddle.no_grad() + def stable_diffusion_text2image( + self, + input_ids, + seed: Optional[int] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + latents: Optional[paddle.Tensor] = None, + fp16: Optional[bool] = False, + ): + batch_size = input_ids.shape[0] + if height % 64 != 0 or width % 64 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 64 but are {height} and {width}." + ) + + with paddle.amp.auto_cast(enable=fp16, level="O1"): + text_embeddings = self.clip.text_model(input_ids)[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + uncond_embeddings = self.clip.text_model( + self.input_ids_uncond.expand([batch_size, -1]))[0] + text_embeddings = paddle.concat( + [uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + latents_shape = [ + batch_size, self.unet_model.in_channels, height // 8, width // 8 + ] + if latents is None: + if seed is None: + seed = random.randint(0, 2**32) + paddle.seed(seed) + latents = paddle.randn(latents_shape) + else: + if latents.shape != latents_shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" + ) + + # set timesteps + accepts_offset = "offset" in set( + inspect.signature( + self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, + **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in tqdm(enumerate(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = (paddle.concat([latents] * 2) if + do_classifier_free_guidance else latents) + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + latent_model_input = latent_model_input / ( + (sigma**2 + 1)**0.5) + + # predict the noise residual + noise_pred = self.unet_model( + latent_model_input, + t, + encoder_hidden_states=text_embeddings)["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step( + noise_pred, i, latents, + **extra_step_kwargs)["prev_sample"] + else: + latents = self.scheduler.step( + noise_pred, t, latents, + **extra_step_kwargs)["prev_sample"] + + # scale and decode the image latents with vae + image = self.vae_model.decode(1 / 0.18215 * latents) + image = (image / 2 + 0.5).clip(0, 1) + image = image.transpose([0, 2, 3, 1]).cpu().numpy() + image = (image * 255).round().astype(np.uint8) + image = [Image.fromarray(img) for img in image] + + return image + + @paddle.no_grad() + def stable_diffusion_image2image( + self, + input_ids, + init_image, + strength=0.8, + num_inference_steps=50, + guidance_scale=7.5, + eta=0.0, + seed=None, + fp16=False, + ): + batch_size = input_ids.shape[0] + + if strength < 0 or strength > 1: + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}") + + with paddle.amp.auto_cast(enable=fp16, level="O1"): + # set timesteps + accepts_offset = "offset" in set( + inspect.signature( + 
self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, + **extra_set_kwargs) + + # encode the init image into latents and scale the latents + init_latents = self.vae_model.encode(init_image).sample() + init_latents = 0.18215 * init_latents + + # prepare init_latents noise to latents + init_latents = paddle.concat([init_latents] * batch_size) + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = paddle.to_tensor( + [num_inference_steps - init_timestep] * batch_size, + dtype="int64") + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = paddle.to_tensor([timesteps] * batch_size, + dtype="int64") + + # add noise to latents using the timesteps + if seed is None: + seed = random.randint(0, 2**32) + paddle.seed(seed) + noise = paddle.randn(init_latents.shape) + init_latents = self.scheduler.add_noise(init_latents, noise, + timesteps) + + text_embeddings = self.clip.text_model(input_ids)[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + uncond_embeddings = self.clip.text_model( + self.input_ids_uncond.expand([batch_size, -1]))[0] + text_embeddings = paddle.concat( + [uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + t_index = t_start + i + # expand the latents if we are doing classifier free guidance + latent_model_input = (paddle.concat([latents] * 2) if + do_classifier_free_guidance else latents) + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ( + (sigma**2 + 1)**0.5) + latent_model_input = latent_model_input.astype( + paddle.get_default_dtype()) + t = t.astype(paddle.get_default_dtype()) + + # predict the noise residual + noise_pred = self.unet_model( + latent_model_input, + t, + encoder_hidden_states=text_embeddings)["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step( + noise_pred, t_index, latents, + **extra_step_kwargs)["prev_sample"] + else: + latents = self.scheduler.step( + noise_pred, t, latents, + **extra_step_kwargs)["prev_sample"] + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae_model.decode( + latents.astype(paddle.get_default_dtype())) + image = (image / 2 + 0.5).clip(0, 1) + image = image.transpose([0, 2, 3, 1]).cpu().numpy() + image = (image * 255).round().astype(np.uint8) + image = [Image.fromarray(img) for img in image] + return image + + @paddle.no_grad() + def stable_diffusion_inpainting( + self, + input_ids, + init_image, + mask_image, + strength=0.8, + num_inference_steps=50, + guidance_scale=7.5, + eta=0.0, + seed=None, + fp16=False, + ): + batch_size = input_ids.shape[0] + + if strength < 0 or strength > 1: + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}") + + with paddle.amp.auto_cast(enable=fp16, level="O1"): + # set timesteps + accepts_offset = "offset" in set( + inspect.signature( + self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, + **extra_set_kwargs) + + # encode the init image into latents and scale the latents + init_latents = self.vae_model.encode(init_image).sample() + init_latents = 0.18215 * init_latents + + # prepare init_latents noise to latents + init_latents = paddle.concat([init_latents] * batch_size) + init_latents_orig = init_latents + + mask = paddle.concat([mask_image] * batch_size) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError( + f"The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = paddle.to_tensor( + [num_inference_steps - 
init_timestep] * batch_size, + dtype="int64") + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = paddle.to_tensor([timesteps] * batch_size, + dtype="int64") + + # add noise to latents using the timesteps + if seed is None: + seed = random.randint(0, 2**32) + paddle.seed(seed) + noise = paddle.randn(init_latents.shape) + init_latents = self.scheduler.add_noise(init_latents, noise, + timesteps) + + # get prompt text embeddings + text_embeddings = self.clip.text_model(input_ids)[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + uncond_embeddings = self.clip.text_model( + self.input_ids_uncond.expand([batch_size, -1]))[0] + text_embeddings = paddle.concat( + [uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + t_index = t_start + i + # expand the latents if we are doing classifier free guidance + latent_model_input = (paddle.concat([latents] * 2) if + do_classifier_free_guidance else latents) + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ( + (sigma**2 + 1)**0.5) + latent_model_input = latent_model_input.astype( + paddle.get_default_dtype()) + t = t.astype(paddle.get_default_dtype()) + + # predict the noise residual + noise_pred = self.unet_model( + latent_model_input, + t, + encoder_hidden_states=text_embeddings)["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step( + noise_pred, t_index, latents, + **extra_step_kwargs)["prev_sample"] + # masking + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise, t_index) + else: + latents = self.scheduler.step( + noise_pred, t, latents, + **extra_step_kwargs)["prev_sample"] + + # masking + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise, t) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # scale and decode the image latents with vae + image = self.vae_model.decode(1 / 0.18215 * latents) + image = (image / 2 + 0.5).clip(0, 1) + image = image.transpose([0, 2, 3, 
1]).cpu().numpy() + image = (image * 255).round().astype(np.uint8) + image = [Image.fromarray(img) for img in image] + return image diff --git a/paddlenlp/transformers/stable_diffusion_utils/vae.py b/paddlenlp/transformers/stable_diffusion_utils/vae.py new file mode 100644 index 000000000000..b18e5dc3a659 --- /dev/null +++ b/paddlenlp/transformers/stable_diffusion_utils/vae.py @@ -0,0 +1,519 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn + +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +class Encoder(nn.Layer): + + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D", ), + block_out_channels=(64, ), + layers_per_block=2, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2D(in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1) + + self.mid_block = None + self.down_blocks = nn.LayerList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], + num_groups=num_groups_out, + epsilon=1e-6) + self.conv_act = nn.Silu() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2D(block_out_channels[-1], + conv_out_channels, + 3, + padding=1) + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Layer): + + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D", ), + block_out_channels=(64, ), + layers_per_block=2, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2D(in_channels, + block_out_channels[-1], + 
kernel_size=3, + stride=1, + padding=1) + + self.mid_block = None + self.up_blocks = nn.LayerList([]) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], + num_groups=num_groups_out, + epsilon=1e-6) + self.conv_act = nn.Silu() + self.conv_out = nn.Conv2D(block_out_channels[0], + out_channels, + 3, + padding=1) + + def forward(self, z): + sample = z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Layer): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e, + e_dim, + beta, + remap=None, + unknown_index="random", + sane_index_shape=False, + legacy=True, + ): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", paddle.to_tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape([ishape[0], -1]) + used = self.used + match = (inds[:, :, None] == used[None, None, ...]).astype("int64") + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = paddle.randint(0, + self.re_embed, + shape=new[unknown].shape) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape([ishape[0], -1]) + used = self.used + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = paddle.gather(used[None, :][inds.shape[0] * [0], :], + inds, + axis=1) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.transpose([0, 2, 3, 1]) + z_flattened = z.reshape([-1, self.e_dim]) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = (paddle.sum(z_flattened**2, axis=1, keepdim=True) + + paddle.sum(self.embedding.weight**2, axis=1) - 2 * + paddle.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())) + + min_encoding_indices = paddle.argmin(d, axis=1) + z_q = self.embedding(min_encoding_indices).reshape(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * paddle.mean((z_q.detach() - z)**2) + paddle.mean( + (z_q - z.detach())**2) + else: + loss = paddle.mean((z_q.detach() - z)**2) + self.beta * paddle.mean( + (z_q - z.detach())**2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.transpose([0, 3, 1, 2]) + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape( + [z.shape[0], -1]) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape([-1, + 1]) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + [z_q.shape[0], z_q.shape[2], z_q.shape[3]]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape([shape[0], -1]) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.flatten() # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.reshape(shape) + # reshape back to match original input shape + z_q = z_q.transpose([0, 3, 1, 2]) + + return z_q + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = paddle.chunk(parameters, 2, axis=1) + self.logvar = paddle.clip(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = paddle.exp(0.5 * self.logvar) + self.var = paddle.exp(self.logvar) + if self.deterministic: + self.var = self.std = paddle.zeros_like(self.mean) + + def sample(self, seed=None): + if seed is not None: + paddle.seed(seed) + x = self.mean + self.std * paddle.randn(self.mean.shape) + return x + + def kl(self, other=None): + if self.deterministic: + return 
paddle.to_tensor([0.0]) + else: + if other is None: + return 0.5 * paddle.sum( + paddle.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + axis=[1, 2, 3], + ) + else: + return 0.5 * paddle.sum( + paddle.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + axis=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return paddle.to_tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * paddle.sum( + logtwopi + self.logvar + + paddle.pow(sample - self.mean, 2) / self.var, + axis=dims, + ) + + def mode(self): + return self.mean + + +class VQModel(nn.Layer): + + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D", ), + up_block_types=("UpDecoderBlock2D", ), + block_out_channels=(64, ), + layers_per_block=1, + act_fn="silu", + latent_channels=3, + sample_size=32, + num_vq_embeddings=256, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=False, + ) + + self.quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, + latent_channels, + beta=0.25, + remap=None, + sane_index_shape=False, + ) + self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def forward(self, sample): + x = sample + h = self.encode(x) + dec = self.decode(h) + return dec + + +class AutoencoderKL(nn.Layer): + + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types=( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels=(128, 256, 512, 512), + layers_per_block=2, + act_fn="silu", + latent_channels=4, + sample_size=512, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = 
self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, sample, sample_posterior=False): + x = sample + posterior = self.encode(x) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec
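
Editor's note (not part of the diff): the pipelines above and the `AutoencoderKL` added in `stable_diffusion_utils/vae.py` fit together as follows: an image is encoded into a `DiagonalGaussianDistribution` over latents, the sampled latents are scaled by `0.18215` before the denoising loop, and decoding inverts both steps. The sketch below illustrates that round trip only; it assumes the module path introduced by this diff and uses a randomly initialized model in place of pretrained weights, so it is not a drop-in usage of the PR's pipelines.

import paddle

# Module path as added in this PR; in practice a pretrained checkpoint would be
# loaded instead of the randomly initialized weights used here.
from paddlenlp.transformers.stable_diffusion_utils.vae import AutoencoderKL

vae = AutoencoderKL()                       # defaults: 4 latent channels, 8x spatial downsampling
image = paddle.randn([1, 3, 512, 512])      # stand-in for a preprocessed image batch in [-1, 1]

with paddle.no_grad():
    posterior = vae.encode(image)           # DiagonalGaussianDistribution over the latents
    latents = 0.18215 * posterior.sample()  # same scaling the diffusion loops above apply
    recon = vae.decode(latents / 0.18215)   # undo the scaling before decoding, as the pipelines do
    recon = (recon / 2 + 0.5).clip(0, 1)    # map back from [-1, 1] to [0, 1] for visualization

For reference, `strength` in the image-to-image and inpainting methods above controls how much of the schedule is re-run: with, e.g., `num_inference_steps=50`, `strength=0.8` (the inpainting defaults shown in this diff), and `offset=1` when the scheduler accepts it, `init_timestep = int(50 * 0.8) + 1 = 41` and `t_start = max(50 - 41 + 1, 0) = 10`, so the denoising loop runs over the last 40 of the 50 timesteps.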