diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py
index 58d814796ede..6af11e887fb6 100644
--- a/paddlenlp/transformers/__init__.py
+++ b/paddlenlp/transformers/__init__.py
@@ -121,6 +121,10 @@
from .artist.tokenizer import *
from .dallebart.modeling import *
from .dallebart.tokenizer import *
+from .clip.modeling import *
+from .clip.feature_extraction import *
+from .clip.tokenizer import *
+from .clip.procesing import *
from .gptj.modeling import *
from .gptj.tokenizer import *
diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py
index 706074d91be5..97aa79a20484 100644
--- a/paddlenlp/transformers/auto/modeling.py
+++ b/paddlenlp/transformers/auto/modeling.py
@@ -88,6 +88,7 @@
("Bart", "bart"),
("GAUAlpha", "gau_alpha"),
("CodeGen", "codegen"),
+ ("CLIP", "clip"),
("Artist", "artist"),
("OPT", 'opt')
])
diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py
index a6d056bbcb4f..4b759ea8b659 100644
--- a/paddlenlp/transformers/auto/tokenizer.py
+++ b/paddlenlp/transformers/auto/tokenizer.py
@@ -79,6 +79,7 @@
("BartTokenizer", "bart"),
("GAUAlphaTokenizer", "gau_alpha"),
("CodeGenTokenizer", "codegen"),
+ ("CLIPTokenizer", "clip"),
("ArtistTokenizer", "artist"),
])
diff --git a/paddlenlp/transformers/clip/__init__.py b/paddlenlp/transformers/clip/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/paddlenlp/transformers/clip/feature_extraction.py b/paddlenlp/transformers/clip/feature_extraction.py
new file mode 100644
index 000000000000..d1c9e68646ea
--- /dev/null
+++ b/paddlenlp/transformers/clip/feature_extraction.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for CLIP."""
+
+from typing import List, Optional, Union
+
+import paddle
+import numpy as np
+from PIL import Image
+
+from ..feature_extraction_utils import BatchFeature
+from ..tokenizer_utils_base import TensorType
+from ..image_utils import ImageFeatureExtractionMixin
+
+__all__ = ["CLIPFeatureExtractor"]
+
+
+class CLIPFeatureExtractor(ImageFeatureExtractionMixin):
+ r"""
+ Constructs a CLIP feature extractor.
+ This feature extractor inherits from [`ImageFeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int`, *optional*, defaults to 224):
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped.
+ crop_size (`int`, *optional*, defaults to 224):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with `image_mean` and `image_std`.
+        image_mean (`List[float]`, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            The sequence of means for each channel, to be used when normalizing images.
+        image_std (`List[float]`, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images.
+        do_convert_rgb (`bool`, defaults to `True`):
+            Whether or not to convert `PIL.Image.Image` into `RGB` format.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(self,
+ do_resize=True,
+ size=224,
+ resample=Image.BICUBIC,
+ do_center_crop=True,
+ crop_size=224,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ do_convert_rgb=True,
+ **kwargs):
+ super().__init__()
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else [
+ 0.48145466, 0.4578275, 0.40821073
+ ]
+ self.image_std = image_std if image_std is not None else [
+ 0.26862954, 0.26130258, 0.27577711
+ ]
+ self.do_convert_rgb = do_convert_rgb
+
+ def __call__(
+ self,
+ images: Union[Image.Image, np.ndarray, "paddle.Tensor",
+ List[Image.Image], List[np.ndarray],
+ List["paddle.Tensor"] # noqa
+ ],
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs):
+ """
+        Main method to prepare one or several image(s) for the model.
+
+        NumPy arrays and Paddle tensors are converted to PIL images when resizing, so it is most efficient to pass
+        PIL images.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
+                tensor. In the case of a NumPy array/Paddle tensor, each image should be of shape (C, H, W), where C is
+                the number of channels, and H and W are the image height and width.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'pd'`: Return Paddle `paddle.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+ - **pixel_values** -- Pixel values to be fed to a model.
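+
+        Examples:
+            A minimal usage sketch with the default preprocessing configuration; the
+            sample image URL is only illustrative.
+
+            .. code-block::
+
+                import requests
+                from PIL import Image
+                from paddlenlp.transformers import CLIPFeatureExtractor
+
+                feature_extractor = CLIPFeatureExtractor()
+
+                url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+                image = Image.open(requests.get(url, stream=True).raw)
+                inputs = feature_extractor(images=image, return_tensors="pd")
+
+                # the default pipeline resizes, center-crops to 224x224 and normalizes
+                print(inputs["pixel_values"].shape)  # [1, 3, 224, 224]
+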
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images,
+ (Image.Image, np.ndarray)) or paddle.is_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(
+ images[0],
+ (Image.Image, np.ndarray)) or paddle.is_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `paddle.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[paddle.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray))
+ or paddle.is_tensor(images[0])))
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (convert rgb + resizing + center cropping + normalization)
+ if self.do_convert_rgb:
+ images = [self.convert_rgb(image) for image in images]
+ if self.do_resize and self.size is not None and self.resample is not None:
+ images = [
+ self.resize(image=image,
+ size=self.size,
+ resample=self.resample,
+ default_to_square=False) for image in images
+ ]
+ if self.do_center_crop and self.crop_size is not None:
+ images = [
+ self.center_crop(image, self.crop_size) for image in images
+ ]
+ if self.do_normalize:
+ images = [
+ self.normalize(image=image,
+ mean=self.image_mean,
+ std=self.image_std) for image in images
+ ]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/paddlenlp/transformers/clip/modeling.py b/paddlenlp/transformers/clip/modeling.py
new file mode 100644
index 000000000000..95081f0a315d
--- /dev/null
+++ b/paddlenlp/transformers/clip/modeling.py
@@ -0,0 +1,1891 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Tuple, Optional, Union
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from dataclasses import dataclass
+from ..model_outputs import BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput
+from .. import PretrainedModel, register_base_model
+from ..stable_diffusion_utils import StableDiffusionMixin, AutoencoderKL, PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, UNet2DConditionModel
+from ..guided_diffusion_utils import DiscoDiffusionMixin, create_gaussian_diffusion, create_unet_model, create_secondary_model
+
+__all__ = [
+ 'VisionTransformer',
+ 'TextTransformer',
+ 'CLIPTextModel',
+ 'CLIPVisionModel',
+ 'CLIPPretrainedModel',
+ 'CLIPModel',
+ 'CLIPForImageGeneration',
+ 'ModifiedResNet',
+]
+
+
+@dataclass
+class CLIPOutput(ModelOutput):
+ """
+ Args:
+ loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+        logits_per_image (`paddle.Tensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+        logits_per_text (`paddle.Tensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+        text_embeds (`paddle.Tensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`TextTransformer`].
+        image_embeds (`paddle.Tensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of [`VisionTransformer`].
+ text_model_output(:class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`):
+ The output of the [`TextTransformer`].
+ vision_model_output(:class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`):
+ The output of the [`VisionTransformer`].
+ """
+
+ loss: Optional[paddle.Tensor] = None
+ logits_per_image: paddle.Tensor = None
+ logits_per_text: paddle.Tensor = None
+ text_embeds: paddle.Tensor = None
+ image_embeds: paddle.Tensor = None
+ text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
+ vision_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"
+ ] else getattr(self, k).to_tuple()
+ for k in self.keys())
+
+
+# contrastive loss function, adapted from
+# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
+def contrastive_loss(logits):
+ return F.cross_entropy(logits, paddle.arange(len(logits)))
+
+
+def clip_loss(similarity):
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(similarity.t())
+ return (caption_loss + image_loss) / 2.0
+
+
+class ModifiedResNet(nn.Layer):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self,
+ layers,
+ output_dim,
+ heads,
+ input_resolution=224,
+ width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2D(3,
+ width // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(width // 2)
+ self.conv2 = nn.Conv2D(width // 2,
+ width // 2,
+ kernel_size=3,
+ padding=1,
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm2D(width // 2)
+ self.conv3 = nn.Conv2D(width // 2,
+ width,
+ kernel_size=3,
+ padding=1,
+ bias_attr=False)
+ self.bn3 = nn.BatchNorm2D(width)
+ self.avgpool = nn.AvgPool2D(2)
+ self.relu = nn.ReLU()
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
+ heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
+ (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+def multi_head_attention_forward(x: paddle.Tensor,
+ num_heads: int,
+ q_proj: nn.Linear,
+ k_proj: nn.Linear,
+ v_proj: nn.Linear,
+ c_proj: nn.Linear,
+ attn_mask: Optional[paddle.Tensor] = None):
+ max_len, batch_size, emb_dim = x.shape
+ head_dim = emb_dim // num_heads
+ scaling = float(head_dim)**-0.5
+ q = q_proj(x) # L, N, E
+ k = k_proj(x) # L, N, E
+ v = v_proj(x) # L, N, E
+
+ v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
+ k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
+ q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
+
+ q = q * scaling
+ qk = paddle.matmul(q, k, transpose_y=True)
+ if attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask.unsqueeze_(0)
+ assert attn_mask.shape[0] == 1 and attn_mask.shape[
+ 1] == max_len and attn_mask.shape[2] == max_len
+ qk += attn_mask
+
+ qk = F.softmax(qk, axis=-1)
+ atten = paddle.bmm(qk, v)
+ atten = atten.transpose((1, 0, 2))
+ atten = atten.reshape((max_len, batch_size, emb_dim))
+ atten = c_proj(atten)
+ return atten
+
+
+class Identity(nn.Layer):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+
+class Bottleneck(nn.Layer):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(planes)
+
+ self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
+ self.bn2 = nn.BatchNorm2D(planes)
+
+ self.avgpool = nn.AvgPool2D(stride) if stride > 1 else Identity()
+
+ self.conv3 = nn.Conv2D(planes,
+ planes * self.expansion,
+ 1,
+ bias_attr=False)
+ self.bn3 = nn.BatchNorm2D(planes * self.expansion)
+
+ self.relu = nn.ReLU()
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ self.downsample = nn.Sequential(
+ ("-1", nn.AvgPool2D(stride)),
+ ("0",
+ nn.Conv2D(inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias_attr=False)),
+ ("1", nn.BatchNorm2D(planes * self.expansion)))
+
+ def forward(self, x):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Layer):
+
+ def __init__(self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads: int,
+ output_dim: int = None):
+ super().__init__()
+
+ self.positional_embedding = nn.Embedding(spacial_dim**2 + 1, embed_dim)
+
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
+ self.c_proj = nn.Linear(embed_dim,
+ output_dim or embed_dim,
+ bias_attr=True)
+ self.num_heads = num_heads
+
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+ def forward(self, x):
+
+ x = x.reshape(
+ (x.shape[0], x.shape[1], x.shape[2] * x.shape[3])).transpose(
+ (2, 0, 1)) # NCHW -> (HW)NC
+ x = paddle.concat([x.mean(axis=0, keepdim=True), x], axis=0)
+ x = x + paddle.unsqueeze(self.positional_embedding.weight, 1)
+ out = multi_head_attention_forward(x, self.num_heads, self.q_proj,
+ self.k_proj, self.v_proj,
+ self.c_proj)
+
+ return out[0]
+
+
+class CLIPPretrainedModel(PretrainedModel):
+ """
+ An abstract class for pretrained CLIP models. It provides CLIP related
+ `model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
+ `pretrained_resource_files_map`, `base_model_prefix` for downloading and
+ loading pretrained models.
+ See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
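+
+    Example:
+        A minimal loading sketch; any key of `pretrained_init_configuration`
+        (for example, "openai/clip-rn50") can be passed as the pretrained model name.
+
+        .. code-block::
+
+            from paddlenlp.transformers import CLIPModel
+
+            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+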
+ """
+ model_config_file = "model_config.json"
+ pretrained_init_configuration = {
+ "openai/clip-vit-base-patch32": {
+ # vision
+ "image_resolution": 224,
+ "vision_layers": 12,
+ "vision_heads": 12,
+ "vision_mlp_ratio": 4,
+ "vision_embed_dim": 768,
+ "vision_patch_size": 32,
+ "vision_hidden_act": "quick_gelu",
+ # text
+ "max_text_length": 77,
+ "vocab_size": 49408,
+ "text_embed_dim": 512,
+ "text_heads": 8,
+ "text_layers": 12,
+ "text_hidden_act": "quick_gelu",
+ # others
+ "projection_dim": 512,
+ "initializer_range": 0.02,
+ "logit_scale_init_value": 2.6592
+ },
+ "openai/clip-rn50": {
+ # vision
+ "image_resolution": 224,
+ "vision_layers": [3, 4, 6, 3],
+ "vision_heads": 32,
+ "vision_mlp_ratio": None, # do not use
+ "vision_embed_dim": 64, # vision width
+ "vision_patch_size": None, # do not use
+ "vision_hidden_act": None, # do not use
+ # text
+ "max_text_length": 77,
+ "vocab_size": 49408,
+ "text_embed_dim": 512,
+ "text_heads": 8,
+ "text_layers": 12,
+ "text_hidden_act": "quick_gelu",
+ # others
+ "projection_dim": 1024,
+ "initializer_range": 0.02,
+ "logit_scale_init_value": 2.6592
+ },
+ "openai/clip-rn101": {
+ # vision
+ "image_resolution": 224,
+ "vision_layers": [3, 4, 23, 3],
+ "vision_heads": 32,
+ "vision_mlp_ratio": None, # do not use
+ "vision_embed_dim": 64, # vision width
+ "vision_patch_size": None, # do not use
+ "vision_hidden_act": None, # do not use
+ # text
+ "max_text_length": 77,
+ "vocab_size": 49408,
+ "text_embed_dim": 512,
+ "text_heads": 8,
+ "text_layers": 12,
+ "text_hidden_act": "quick_gelu",
+ # others
+ "projection_dim": 512,
+ "initializer_range": 0.02,
+ "logit_scale_init_value": 2.6592
+ },
+ "openai/clip-vit-large-patch14": {
+ # vision
+ "image_resolution": 224,
+ "vision_layers": 24,
+ "vision_heads": 16,
+ "vision_mlp_ratio": 4,
+ "vision_embed_dim": 1024,
+ "vision_patch_size": 14,
+ "vision_hidden_act": "quick_gelu",
+ # text
+ "max_text_length": 77,
+ "vocab_size": 49408,
+ "text_embed_dim": 768,
+ "text_heads": 12,
+ "text_layers": 12,
+ "text_hidden_act": "quick_gelu",
+ # others
+ "projection_dim": 768,
+ "initializer_range": 0.02,
+ "logit_scale_init_value": 2.6592
+ },
+ }
+ resource_files_names = {"model_state": "model_state.pdparams"}
+ pretrained_resource_files_map = {
+ "model_state": {
+ "openai/clip-vit-base-patch32":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-base-patch32/model_state.pdparams",
+ "openai/clip-rn50":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn50/model_state.pdparams",
+ "openai/clip-rn101":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn101/model_state.pdparams",
+ "openai/clip-vit-large-patch14":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-large-patch14/model_state.pdparams",
+ }
+ }
+ base_model_prefix = "clip"
+
+ def _init_weights(self, layer):
+ """Initialize the weights"""
+ initializer_range = self.initializer_range if hasattr(
+ self,
+ "initializer_range") else self.clip.config["initializer_range"]
+ factor = self.initializer_factor if hasattr(
+ self,
+ "initializer_factor") else self.clip.config["initializer_factor"]
+
+ if isinstance(layer, VisionTransformer):
+ vision_embed_dim = self.vision_embed_dim if hasattr(
+ self,
+ "vision_embed_dim") else self.clip.config["vision_embed_dim"]
+ vision_layers = self.vision_layers if hasattr(
+ self, "vision_layers") else self.clip.config["vision_layers"]
+ # vision embedding
+ layer.class_embedding.set_value(
+ paddle.normal(
+ std=vision_embed_dim**-0.5 * factor,
+ shape=layer.class_embedding.shape,
+ ))
+ layer.conv1.weight.set_value(
+ paddle.normal(
+ std=initializer_range * factor,
+ shape=layer.conv1.weight.shape,
+ ))
+ layer.positional_embedding.weight.set_value(
+ paddle.normal(
+ std=initializer_range * factor,
+ shape=layer.positional_embedding.weight.shape,
+ ))
+
+ elif isinstance(layer, TextTransformer):
+ text_embed_dim = self.text_embed_dim if hasattr(
+ self, "text_embed_dim") else self.clip.config["text_embed_dim"]
+ text_layers = self.text_layers if hasattr(
+ self, "text_layers") else self.clip.config["text_layers"]
+ # text embedding
+ layer.token_embedding.weight.set_value(
+ paddle.normal(
+ mean=0.0,
+ std=factor * 0.02,
+ shape=layer.token_embedding.weight.shape,
+ ))
+ layer.positional_embedding.weight.set_value(
+ paddle.normal(
+ mean=0.0,
+ std=factor * 0.02,
+ shape=layer.positional_embedding.weight.shape,
+ ))
+ elif isinstance(layer, CLIPModel):
+ vision_embed_dim = self.vision_embed_dim if hasattr(
+ self,
+ "vision_embed_dim") else self.clip.config["vision_embed_dim"]
+ vision_layers = self.vision_layers if hasattr(
+ self, "vision_layers") else self.clip.config["vision_layers"]
+ text_embed_dim = self.text_embed_dim if hasattr(
+ self, "text_embed_dim") else self.clip.config["text_embed_dim"]
+ text_layers = self.text_layers if hasattr(
+ self, "text_layers") else self.clip.config["text_layers"]
+ layer.text_projection.set_value(
+ paddle.normal(
+ std=text_embed_dim**-0.5 * factor,
+ shape=layer.text_projection.shape,
+ ))
+ if hasattr(layer, "vision_projection"):
+ layer.vision_projection.set_value(
+ paddle.normal(
+ std=vision_embed_dim**-0.5 * factor,
+ shape=layer.vision_projection.shape,
+ ))
+ for name, sub_layer in layer.named_sublayers():
+ num_layers = vision_layers if "vision_model" in name else text_layers
+ if isinstance(sub_layer, nn.TransformerEncoderLayer):
+ # self_attn
+ in_proj_std = (sub_layer.self_attn.embed_dim**-0.5) * (
+ (2 * num_layers)**-0.5) * factor
+ out_proj_std = (sub_layer.self_attn.embed_dim**
+ -0.5) * factor
+ sub_layer.self_attn.q_proj.weight.set_value(
+ paddle.normal(
+ std=in_proj_std,
+ shape=sub_layer.self_attn.q_proj.weight.shape,
+ ))
+ sub_layer.self_attn.k_proj.weight.set_value(
+ paddle.normal(
+ std=in_proj_std,
+ shape=sub_layer.self_attn.k_proj.weight.shape,
+ ))
+ sub_layer.self_attn.v_proj.weight.set_value(
+ paddle.normal(
+ std=in_proj_std,
+ shape=sub_layer.self_attn.v_proj.weight.shape,
+ ))
+ sub_layer.self_attn.out_proj.weight.set_value(
+ paddle.normal(
+ std=out_proj_std,
+ shape=sub_layer.self_attn.out_proj.weight.shape,
+ ))
+ # ffn
+ in_proj_std = ((sub_layer._config["d_model"]**-0.5) *
+ ((2 * num_layers)**-0.5) * factor)
+ fc_std = (2 * sub_layer._config["d_model"])**-0.5 * factor
+ sub_layer.linear1.weight.set_value(
+ paddle.normal(
+ std=fc_std,
+ shape=sub_layer.linear1.weight.shape,
+ ))
+ sub_layer.linear2.weight.set_value(
+ paddle.normal(
+ std=in_proj_std,
+ shape=sub_layer.linear2.weight.shape,
+ ))
+ if isinstance(layer, nn.LayerNorm):
+ layer.bias.set_value(paddle.zeros_like(layer.bias))
+ layer.weight.set_value(paddle.ones_like(layer.weight))
+ if isinstance(layer, nn.Linear) and layer.bias is not None:
+ layer.bias.set_value(paddle.zeros_like(layer.bias))
+
+
+# set attr
+def quick_gelu(x):
+ return x * F.sigmoid(1.702 * x)
+
+
+F.quick_gelu = quick_gelu
+
+NEG_INF = float("-inf") # -1e4 -1e9
+
+
+class VisionTransformer(nn.Layer):
+
+ def __init__(self,
+ input_resolution: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ activation: str,
+ mlp_ratio: int,
+ normalize_before: bool = True):
+ super().__init__()
+ self.input_resolution = input_resolution
+ # used patch_size x patch_size, stride patch_size to do linear projection
+ self.conv1 = nn.Conv2D(in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias_attr=False)
+
+ self.class_embedding = paddle.create_parameter(
+ (width, ), paddle.get_default_dtype())
+
+ self.positional_embedding = nn.Embedding(
+ (input_resolution // patch_size)**2 + 1, width)
+
+ self.ln_pre = nn.LayerNorm(width)
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=width,
+ nhead=heads,
+ dim_feedforward=width * mlp_ratio,
+ normalize_before=normalize_before,
+ dropout=0,
+ activation=activation,
+ attn_dropout=0,
+ act_dropout=0)
+ self.transformer = nn.TransformerEncoder(encoder_layer, layers)
+
+ self.ln_post = nn.LayerNorm(width)
+
+ def forward(self,
+ pixel_values,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False):
+ pixel_values = self.conv1(pixel_values)
+ pixel_values = pixel_values.reshape(
+ (pixel_values.shape[0], pixel_values.shape[1], -1))
+ pixel_values = pixel_values.transpose((0, 2, 1))
+ embedding_output = paddle.concat([
+ self.class_embedding.unsqueeze([0, 1]).expand(
+ [pixel_values.shape[0], -1, -1]), pixel_values
+ ],
+ axis=1)
+ embedding_output = embedding_output + self.positional_embedding.weight
+ embedding_output = self.ln_pre(embedding_output)
+ encoder_outputs = self.transformer(
+ embedding_output,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict)
+
+ if isinstance(encoder_outputs, type(embedding_output)):
+ last_hidden_state = encoder_outputs
+ else:
+ last_hidden_state = encoder_outputs[0]
+
+ pooled_output = self.ln_post(last_hidden_state[:, 0])
+
+ if isinstance(encoder_outputs, type(embedding_output)):
+ return (last_hidden_state, pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions)
+
+
+class TextTransformer(nn.Layer):
+
+ def __init__(self,
+ context_length,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ vocab_size,
+ activation="quick_gelu",
+ normalize_before=True):
+ super().__init__()
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=transformer_width,
+ nhead=transformer_heads,
+ dim_feedforward=transformer_width * 4,
+ normalize_before=normalize_before,
+ dropout=0,
+ activation=activation,
+ attn_dropout=0,
+ act_dropout=0)
+ self.transformer = nn.TransformerEncoder(encoder_layer,
+ transformer_layers)
+
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Embedding(context_length,
+ transformer_width)
+ self.ln_final = nn.LayerNorm(transformer_width)
+
+ self.register_buffer("causal_mask",
+ paddle.triu(paddle.ones(
+ (1, 1, context_length, context_length)) *
+ NEG_INF,
+ diagonal=1),
+ persistable=False)
+ self.register_buffer("position_ids",
+ paddle.arange(context_length).reshape((1, -1)),
+ persistable=False)
+
+ def forward(self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False):
+ bs, seqlen = input_ids.shape[:2]
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seqlen].astype("int64")
+
+ causal_mask = self.causal_mask[:, :, :seqlen, :seqlen]
+ if attention_mask is not None:
+ assert attention_mask.ndim == 2
+ expanded_mask = attention_mask[:, None, None, :].expand(
+ [bs, 1, seqlen, -1]).astype(causal_mask.dtype)
+ inverted_mask = (1.0 - expanded_mask) * NEG_INF
+ attention_mask = inverted_mask + causal_mask
+ else:
+ attention_mask = causal_mask
+ attention_mask.stop_gradient = True
+
+ embedding_output = self.token_embedding(
+ input_ids) + self.positional_embedding(
+ position_ids) # [batch_size, n_ctx, d_model]
+
+ encoder_outputs = self.transformer(
+ embedding_output,
+ src_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict)
+
+ if isinstance(encoder_outputs, type(embedding_output)):
+ last_hidden_state = encoder_outputs
+ else:
+ last_hidden_state = encoder_outputs[0]
+
+ last_hidden_state = self.ln_final(last_hidden_state)
+ pooled_output = last_hidden_state.gather_nd(
+ paddle.stack(
+ [paddle.arange(bs), input_ids.argmax(-1)], axis=-1))
+
+ if isinstance(encoder_outputs, type(embedding_output)):
+ return (last_hidden_state, pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions)
+
+
+@register_base_model
+class CLIPModel(CLIPPretrainedModel):
+ r"""
+ The bare CLIP Model outputting logits_per_image and logits_per_text.
+ This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
+ Refer to the superclass documentation for the generic methods.
+    This model is also a Paddle `paddle.nn.Layer` subclass. Use it as a regular Paddle Layer
+    and refer to the Paddle documentation for all matters related to general usage and behavior.
+
+ Args:
+ image_resolution (int, optional):
+ The size (resolution) of each image.
+ Defaults to `224`.
+ vision_layers (int, optional):
+ Number of hidden layers in the vision model.
+ Defaults to `12`.
+ vision_heads (int, optional):
+ Number of attention heads for each attention layer in the vision attention.
+ Defaults to `12`.
+ vision_embed_dim (int, optional):
+ Dimensionality of the embedding layer and encoder layers in vision model.
+ Defaults to `768`.
+ vision_patch_size(int, optional):
+ The size (resolution) of each patch.
+ Defaults to `32`.
+ vision_mlp_ratio(int, optional):
+            The ratio between dim_feedforward and vision_embed_dim: `ratio = dim_feedforward / vision_embed_dim`.
+ Defaults to `4`.
+ vision_hidden_act (str, optional):
+ The non-linear activation function of the ffn layer in the vision model.
+ ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported.
+ Defaults to `"quick_gelu"`.
+ max_text_length (int, optional):
+            The maximum number of text position embeddings, which dictates the maximum supported length of the text
+            input sequence. Defaults to `77`.
+        vocab_size (int, optional):
+            Vocabulary size of `input_ids` in `CLIPModel`. It is also the vocabulary size of the text token embedding matrix.
+            Defaults to `49408`.
+        text_embed_dim (int, optional):
+            Dimensionality of the embedding layer and encoder layers in the text model.
+            Defaults to `512`.
+ text_heads (int, optional):
+ Number of attention heads for each attention layer in the text attention.
+ Defaults to `8`.
+ text_layers (int, optional):
+ Number of hidden layers in the text model.
+ Defaults to `12`.
+ text_hidden_act (str, optional):
+ The non-linear activation function of the ffn layer in the text model.
+ ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported.
+ Defaults to `"quick_gelu"`.
+ projection_dim (int, optional):
+            Dimensionality of the text and vision projection layers.
+ Defaults to `512`.
+ initializer_range (float, optional):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ Default to `0.02`.
+ initializer_factor (float, optional):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing). Default to `1.`.
+ logit_scale_init_value (float, optional):
+            The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation.
+ Default to `2.6592`.
+
+ """
+
+ def __init__(
+ self,
+ # vision
+ image_resolution: int = 224,
+ vision_layers: Union[Tuple[int, int, int, int], int] = 12,
+ vision_heads: int = 12,
+ vision_embed_dim: int = 768,
+ vision_patch_size: int = 32,
+ vision_mlp_ratio: int = 4,
+ vision_hidden_act: str = "quick_gelu",
+ # text
+ max_text_length: int = 77,
+ vocab_size: int = 49408,
+ text_embed_dim: int = 512,
+ text_heads: int = 8,
+ text_layers: int = 12,
+ text_hidden_act: str = "quick_gelu",
+ # others
+ projection_dim: int = 512,
+ initializer_range: float = 0.02,
+ initializer_factor: float = 1.0,
+ logit_scale_init_value: float = 2.6592):
+ super().__init__()
+ self.initializer_factor = initializer_factor
+ self.initializer_range = initializer_range
+ self.logit_scale_init_value = logit_scale_init_value
+ self.vision_embed_dim = vision_embed_dim
+ self.text_embed_dim = text_embed_dim
+ self.vision_layers = vision_layers
+ self.text_layers = text_layers
+
+ if isinstance(vision_layers, (tuple, list)):
+ if vision_heads is None:
+ vision_heads = vision_embed_dim * 32 // 64
+ self.vision_model = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=projection_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_embed_dim)
+ else:
+ if vision_heads is None:
+ vision_heads = vision_embed_dim // 64
+ self.vision_model = VisionTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_embed_dim,
+ layers=vision_layers,
+ heads=vision_heads,
+ activation=vision_hidden_act,
+ mlp_ratio=vision_mlp_ratio,
+ normalize_before=True)
+ self.vision_projection = paddle.create_parameter(
+ (vision_embed_dim, projection_dim), paddle.get_default_dtype())
+
+ self.text_model = TextTransformer(context_length=max_text_length,
+ transformer_width=text_embed_dim,
+ transformer_heads=text_heads,
+ transformer_layers=text_layers,
+ vocab_size=vocab_size,
+ activation=text_hidden_act,
+ normalize_before=True)
+
+ self.text_projection = paddle.create_parameter(
+ (text_embed_dim, projection_dim), paddle.get_default_dtype())
+
+ self.logit_scale = paddle.create_parameter(
+ (1, ),
+ dtype=paddle.get_default_dtype(),
+ default_initializer=nn.initializer.Constant(logit_scale_init_value))
+ self.apply(self._init_weights)
+
+ def get_image_features(self,
+ pixel_values=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False):
+ r"""
+ Returns:
+            image_features (`paddle.Tensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`VisionTransformer`].
+
+ Examples:
+ .. code-block::
+
+ import requests
+ from PIL import Image
+ from paddlenlp.transformers import CLIPProcessor, CLIPModel
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+ inputs = processor(images=image, return_tensors="pd")
+ image_features = model.get_image_features(**inputs)
+
+ """
+ if isinstance(self.vision_model, ModifiedResNet):
+ return self.vision_model(pixel_values)
+ else:
+ vision_outputs = self.vision_model(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict)
+ pooled_output = vision_outputs[1]
+ image_features = paddle.matmul(pooled_output,
+ self.vision_projection)
+ return image_features
+
+ def get_text_features(
+ self,
+ input_ids,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False,
+ ):
+ r"""
+ Returns:
+            text_features (`paddle.Tensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`TextTransformer`].
+
+ Example:
+ .. code-block::
+
+ from paddlenlp.transformers import CLIPModel, CLIPTokenizer
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pd")
+ text_features = model.get_text_features(**inputs)
+
+ """
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict)
+ pooled_output = text_outputs[1]
+ text_features = paddle.matmul(pooled_output, self.text_projection)
+ return text_features
+
+ def forward(self,
+ input_ids,
+ pixel_values,
+ attention_mask=None,
+ position_ids=None,
+ return_loss=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False):
+ r'''
+ The CLIPModel forward method, overrides the `__call__()` special method.
+
+ Args:
+ input_ids (Tensor):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
+ Its data type should be `int64` and it has a shape of [text_batch_size, sequence_length].
+ pixel_values (Tensor):
+ Pixel values. Padding will be ignored by default should you provide it.
+ Its data type should be `float32` and it has a shape of [image_batch_size, num_channels, height, width].
+ position_ids(Tensor, optional):
+ Indices of positions of each input sequence tokens in the position embeddings (TextTransformer). Selected in
+ the range ``[0, max_text_length - 1]``.
+ Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`.
+ attention_mask (Tensor, optional):
+                Mask used in multi-head attention (TextTransformer) to avoid performing attention on some unwanted positions,
+ usually the paddings or the subsequent positions.
+ Its data type can be int, float and bool.
+ When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
+ When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
+ When the data type is float, the `masked` tokens have `0.0` values and the others have `1.0` values.
+                It is a tensor with shape `[batch_size, sequence_length]`.
+                Defaults to `None`, which means no padding positions are masked.
+ output_hidden_states (bool, optional):
+ Whether to return the hidden states of all layers.
+ Defaults to `False`.
+ output_attentions (bool, optional):
+ Whether to return the attentions tensors of all attention layers.
+ Defaults to `False`.
+ return_dict (bool, optional):
+ Whether to return a :class:`CLIPOutput` object. If `False`, the output
+ will be a tuple of tensors. Defaults to `False`.
+
+ Returns:
+ An instance of :class:`CLIPOutput` if `return_dict=True`. Otherwise it returns a tuple of tensors
+ corresponding to ordered and not None (depending on the input arguments) fields of :class:`CLIPOutput`.
+
+ Example:
+ .. code-block::
+
+ import requests
+ import paddle.nn.functional as F
+ from PIL import Image
+ from paddlenlp.transformers import CLIPModel, CLIPProcessor
+
+ processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
+ model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ inputs = processor(text=["a photo of a cat", "a photo of a dog"],
+ images=image,
+ padding=True,
+ return_tensors="pd")
+
+ outputs = model(**inputs)
+
+ logits_per_image = outputs[0]
+ probs = F.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
+
+ '''
+ if isinstance(self.vision_model, ModifiedResNet):
+ vision_outputs = None
+ image_embeds = self.vision_model(pixel_values)
+ else:
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ image_embeds = paddle.matmul(vision_outputs[1],
+ self.vision_projection)
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ text_embeds = paddle.matmul(text_outputs[1], self.text_projection)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(axis=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(axis=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = paddle.matmul(
+ text_embeds, image_embeds, transpose_y=True) * logit_scale
+ logits_per_image = logits_per_text.t()
+
+ loss = None
+
+ if return_loss:
+ loss = clip_loss(logits_per_text)
+
+ if not return_dict:
+ output = (logits_per_image, logits_per_text, text_embeds,
+ image_embeds, text_outputs, vision_outputs)
+ return ((loss, ) + output) if loss is not None else output
+
+ return CLIPOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+class CLIPTextPretrainedModel(CLIPPretrainedModel):
+ pass
+
+
+@register_base_model
+class CLIPTextModel(CLIPTextPretrainedModel):
+ r"""
+ The bare CLIPTextModel outputting :class:`BaseModelOutputWithPoolingAndCrossAttentions`.
+ This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
+ Refer to the superclass documentation for the generic methods.
+    This model is also a Paddle `paddle.nn.Layer` subclass. Use it as a regular Paddle Layer
+    and refer to the Paddle documentation for all matters related to general usage and behavior.
+
+ Args:
+ max_text_length (int, optional):
+            The maximum number of text position embeddings, which dictates the maximum supported length of the text
+            input sequence. Defaults to `77`.
+        vocab_size (int, optional):
+            Vocabulary size of `input_ids` in `CLIPTextModel`. It is also the vocabulary size of the text token embedding matrix.
+            Defaults to `49408`.
+        text_embed_dim (int, optional):
+            Dimensionality of the embedding layer and encoder layers in the text model.
+            Defaults to `512`.
+ text_heads (int, optional):
+ Number of attention heads for each attention layer in the text attention.
+ Defaults to `8`.
+ text_layers (int, optional):
+ Number of hidden layers in the text model.
+ Defaults to `12`.
+ text_hidden_act (str, optional):
+ The non-linear activation function of the ffn layer in the text model.
+ ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported.
+ Defaults to `"quick_gelu"`.
+ initializer_range (float, optional):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ Default to `0.02`.
+ initializer_factor (float, optional):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing). Default to `1.`.
+
+ """
+
+ def __init__(self,
+ max_text_length=77,
+ text_embed_dim=512,
+ text_heads=8,
+ text_layers=12,
+ vocab_size=49408,
+ text_hidden_act="quick_gelu",
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs):
+ super().__init__()
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.text_embed_dim = text_embed_dim
+ self.text_layers = text_layers
+ self.text_model = TextTransformer(context_length=max_text_length,
+ transformer_width=text_embed_dim,
+ transformer_heads=text_heads,
+ transformer_layers=text_layers,
+ vocab_size=vocab_size,
+ activation=text_hidden_act,
+ normalize_before=True)
+ self.apply(self._init_weights)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False,
+ ):
+ r"""
+ The CLIPTextModel forward method, overrides the `__call__()` special method.
+
+ Args:
+ input_ids (Tensor):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
+ Its data type should be `int64` and it has a shape of [text_batch_size, sequence_length].
+ attention_mask (Tensor, optional):
+                Mask used in multi-head attention (TextTransformer) to avoid performing attention on some unwanted positions,
+ usually the paddings or the subsequent positions.
+ Its data type can be int, float and bool.
+ When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
+ When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
+ When the data type is float, the `masked` tokens have `0.0` values and the others have `1.0` values.
+                It is a tensor with shape `[batch_size, sequence_length]`.
+                Defaults to `None`, which means no padding positions are masked.
+ position_ids(Tensor, optional):
+ Indices of positions of each input sequence tokens in the position embeddings (TextTransformer). Selected in
+ the range ``[0, max_text_length - 1]``.
+ Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`.
+ output_hidden_states (bool, optional):
+ Whether to return the hidden states of all layers.
+ Defaults to `False`.
+ output_attentions (bool, optional):
+ Whether to return the attentions tensors of all attention layers.
+ Defaults to `False`.
+ return_dict (bool, optional):
+ Whether to return a :class:`BaseModelOutputWithPoolingAndCrossAttentions` object. If `False`, the output
+ will be a tuple of tensors. Defaults to `False`.
+
+ Returns:
+ An instance of :class:`BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors
+ corresponding to ordered and not None (depending on the input arguments) fields of :class:`BaseModelOutputWithPoolingAndCrossAttentions`.
+
+ Example:
+ .. code-block::
+
+ from paddlenlp.transformers import CLIPTokenizer, CLIPTextModel
+
+ model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pd")
+ outputs = model(**inputs)
+ last_hidden_state = outputs.last_hidden_state
+ pooled_output = outputs.pooler_output
+
+ """
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class CLIPVisionPretrainedModel(CLIPPretrainedModel):
+ pass
+
+
+@register_base_model
+class CLIPVisionModel(CLIPVisionPretrainedModel):
+ r"""
+ The bare CLIPVisionModel outputting :class:`BaseModelOutputWithPoolingAndCrossAttentions`.
+ This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
+ Refer to the superclass documentation for the generic methods.
+    This model is also a Paddle `paddle.nn.Layer` subclass. Use it as a regular Paddle Layer
+    and refer to the Paddle documentation for all matters related to general usage and behavior.
+
+ Args:
+ image_resolution (int, optional):
+ The size (resolution) of each image.
+ Defaults to `224`.
+ vision_layers (int, optional):
+ Number of hidden layers in the vision model.
+ Defaults to `12`.
+ vision_heads (int, optional):
+ Number of attention heads for each attention layer in the vision attention.
+ Defaults to `12`.
+ vision_embed_dim (int, optional):
+ Dimensionality of the embedding layer and encoder layers in vision model.
+ Defaults to `768`.
+ vision_patch_size(int, optional):
+ The size (resolution) of each patch.
+ Defaults to `32`.
+ vision_mlp_ratio(int, optional):
+            The ratio between dim_feedforward and vision_embed_dim: `ratio = dim_feedforward / vision_embed_dim`.
+ Defaults to `4`.
+ vision_hidden_act (str, optional):
+ The non-linear activation function of the ffn layer in the vision model.
+ ``"gelu"``, ``"relu"``, ``"quick_gelu"`` and any other paddle supported activation functions are supported.
+ Defaults to `"quick_gelu"`.
+ initializer_range (float, optional):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ Default to `0.02`.
+ initializer_factor (float, optional):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing). Default to `1.`.
+
+ """
+
+ def __init__(self,
+ image_resolution=224,
+ vision_patch_size=32,
+ vision_embed_dim=768,
+ vision_layers=12,
+ vision_heads=12,
+ vision_hidden_act="quick_gelu",
+ vision_mlp_ratio=4,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs):
+ super().__init__()
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.vision_embed_dim = vision_embed_dim
+ self.vision_layers = vision_layers
+
+ if vision_heads is None:
+ vision_heads = vision_embed_dim // 64
+ self.vision_model = VisionTransformer(input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_embed_dim,
+ layers=vision_layers,
+ heads=vision_heads,
+ activation=vision_hidden_act,
+ mlp_ratio=vision_mlp_ratio,
+ normalize_before=True)
+
+ self.apply(self._init_weights)
+
+ def forward(
+ self,
+ pixel_values=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False,
+ ):
+ r'''
+ The CLIPVisionModel forward method, overrides the `__call__()` special method.
+
+ Args:
+ pixel_values (Tensor):
+ Pixel values. Padding will be ignored by default should you provide it.
+ Its data type should be `float32` and it has a shape of [image_batch_size, num_channels, height, width].
+ output_hidden_states (bool, optional):
+ Whether to return the hidden states of all layers.
+ Defaults to `False`.
+ output_attentions (bool, optional):
+ Whether to return the attentions tensors of all attention layers.
+ Defaults to `False`.
+ return_dict (bool, optional):
+ Whether to return a :class:`BaseModelOutputWithPoolingAndCrossAttentions` object. If `False`, the output
+ will be a tuple of tensors. Defaults to `False`.
+
+ Returns:
+ An instance of :class:`BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors
+ corresponding to ordered and not None (depending on the input arguments) fields of :class:`BaseModelOutputWithPoolingAndCrossAttentions`.
+
+ Example:
+ .. code-block::
+
+ from PIL import Image
+ import requests
+ from paddlenlp.transformers import CLIPProcessor, CLIPVisionModel
+
+ model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+ inputs = processor(images=image, return_tensors="pd")
+ outputs = model(**inputs)
+
+ last_hidden_state = outputs.last_hidden_state
+ pooled_output = outputs.pooler_output # pooled CLS states
+
+ '''
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class CLIPForImageGeneration(CLIPPretrainedModel, DiscoDiffusionMixin,
+ StableDiffusionMixin):
+ r"""
+ CLIP Model with diffusion model on top.
+
+ Args:
+ clip (:class:`CLIPModel`):
+ An instance of CLIPModel.
+ diffusion_type (str, optional):
+            The type of diffusion model. Must be one of ['disco', 'stable'].
+            Defaults to `disco`.
+        scheduler_type (str, optional):
+            The type of scheduler. Must be one of ['pndm', 'ddim', 'k-lms'].
+            Only used when `diffusion_type` is 'stable'. Defaults to `pndm`.
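+
+    Example:
+        A construction sketch that wraps a pretrained `CLIPModel` with the disco-diffusion
+        head. In practice the full generation weights (diffusion modules included) would be
+        loaded with `from_pretrained` from a checkpoint that bundles them; the direct
+        construction below only illustrates the `clip` and `diffusion_type` arguments.
+
+        .. code-block::
+
+            from paddlenlp.transformers import CLIPModel, CLIPForImageGeneration
+
+            clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+            model = CLIPForImageGeneration(clip, diffusion_type="disco")
+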
+ """
+
+ def __init__(self, clip, diffusion_type="disco", scheduler_type="pndm"):
+ super().__init__()
+ self.clip = clip
+ self.diffusion_type = diffusion_type
+
+ if diffusion_type == "disco":
+ self.unet_model = create_unet_model(
+ image_size=512,
+ num_channels=256,
+ num_res_blocks=2,
+ channel_mult="",
+ learn_sigma=True,
+ class_cond=False,
+ attention_resolutions='32, 16, 8',
+ num_heads=4,
+ num_head_channels=64,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=True,
+ dropout=0.0,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ )
+ self.secondary_model = create_secondary_model()
+
+ elif diffusion_type == "stable":
+ del self.clip.vision_model
+ self.vae_model = AutoencoderKL(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ),
+ up_block_types=(
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ),
+ block_out_channels=(128, 256, 512, 512),
+ layers_per_block=2,
+ act_fn="silu",
+ latent_channels=4,
+ sample_size=512,
+ )
+ if scheduler_type == "pndm":
+ self.scheduler = PNDMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ skip_prk_steps=True,
+ )
+ elif scheduler_type == "ddim":
+ self.scheduler = DDIMScheduler(beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False)
+ elif scheduler_type == "k-lms":
+ self.scheduler = LMSDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear")
+ else:
+ raise ValueError(
+ 'scheduler_type must be in ["pndm", "ddim", "k-lms"]')
+
+ self.unet_model = UNet2DConditionModel(
+ sample_size=64,
+ in_channels=4,
+ out_channels=4,
+ center_input_sample=False,
+ flip_sin_to_cos=True,
+ freq_shift=0,
+ down_block_types=(
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
+ block_out_channels=(320, 640, 1280, 1280),
+ layers_per_block=2,
+ downsample_padding=1,
+ mid_block_scale_factor=1,
+ act_fn="silu",
+ norm_num_groups=32,
+ norm_eps=1e-5,
+ cross_attention_dim=768,
+ attention_head_dim=8,
+ )
+ # input_ids_uncond
+ # [49406, 49407, 49407, 49407, 49407,...,49407]
+ input_ids_uncond = [
+ 49406, 49407
+ ] + [49407] * (clip.config["max_text_length"] - 2)
+ self.register_buffer("input_ids_uncond",
+ paddle.to_tensor([input_ids_uncond],
+ dtype="int64"),
+ persistable=False)
+ else:
+ raise ValueError(
+ "diffusion_type: Please choose in ['disco', 'stable']")
+
+ # eval mode and stop all param's gradient
+ self.eval()
+ for param in self.parameters():
+ param.stop_gradient = True
+
+ def generate(self, *args, **kwargs):
+ if self.diffusion_type == "disco":
+ return self.disco_diffusion_generate(*args, **kwargs)
+ else:
+ return self.stable_diffusion_generate(*args, **kwargs)
+
+ def stable_diffusion_generate(self,
+ input_ids,
+ mode="text2image",
+ init_image=None,
+ mask_image=None,
+ seed=None,
+ strength=0.8,
+ height=512,
+ width=512,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ eta=0.0,
+ latents=None,
+ fp16=True,
+ **kwargs):
+ r"""
+ The CLIPForImageGeneration stable_diffusion_generate method.
+
+ Args:
+ input_ids (Tensor):
+ See :class:`CLIPModel`.
+ mode (str, optional):
+                The generation mode. Supported modes are ["text2image", "image2image", "inpaint"].
+                Defaults to `"text2image"`.
+ init_image (Union[str, PIL.Image.Image], optional):
+ In `"image2image"` or `"inpaint"` mode, we must input the `init_image`.
+ Used in `"image2image"` and `"inpaint"` mode.
+ Default to `None`.
+ mask_image (Union[str, PIL.Image], optional):
+ In `"inpaint"` mode, we must input the `mask_image`.
+ Used in `"inpaint"` mode.
+ Default to `None`.
+ seed (int, optional):
+ A random number seed used as the basis for determining the initial state of the diffusion.
+ Default to `None`.
+ strength (float, optional):
+ strength is a value between 0.0 and 1.0, that controls the amount of noise that is
+ added to the input image. Values that approach 1.0 allow for lots of variations but
+ will also produce images that are not semantically consistent with the input.
+ Used in `"image2image"` and `"inpaint"` mode.
+ Default to `0.8`.
+ height (int, optional):
+ The height of the image you want to generate, in pixels.
+ `height` has to be divisible by 64. Used in `"text2image"` mode.
+ Default to `512`.
+ width (int, optional):
+ The width of the image you want to generate, in pixels.
+ `width` has to be divisible by 64. Used in `"text2image"` mode.
+ Default to `512`.
+ num_inference_steps (int, optional):
+ Indicates the number of steps of inference. Generally speaking, the more steps,
+ the better the result, but the more steps, the longer the generation time.
+ Stable Diffusion gives good results with relatively few steps, so the default
+ value is set to 50. If you want faster speed, you can reduce this value,
+ if you want to generate better results, you can increase this value.
+ Default to `50`.
+ guidance_scale (float, optional):
+ `guidance_scale` is defined analogously to the guidance weight `w` in equation (2)
+ of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ corresponds to doing no classifier-free guidance.
+ Default to `7.5`.
+ eta (float, optional):
+ eta (η) is only used with the DDIMScheduler and will be ignored for other schedulers.
+ eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502.
+ Default to `0.0`.
+ latents (paddle.Tensor, optional):
+ Pre-generated latents to start from. Their shape should be
+ `[batch_size, unet_model.in_channels, height // 8, width // 8]`.
+ Default to `None`.
+ fp16 (bool, optional):
+ Whether to use fp16 for inference.
+ Default to `True`.
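+
+ Example:
+ A minimal text-to-image sketch. The checkpoint name below is illustrative (any
+ CLIPForImageGeneration checkpoint that bundles a Stable Diffusion UNet/VAE is assumed),
+ and it is assumed that the method returns a list of PIL images, like
+ `disco_diffusion_generate`. The padding convention may differ between released checkpoints.
+
+ .. code-block::
+
+ from paddlenlp.transformers import CLIPForImageGeneration, CLIPTokenizer
+
+ # Hypothetical stable-diffusion checkpoint name; replace with an available one.
+ model_name_or_path = "CompVis/stable-diffusion-v1-4"
+ model = CLIPForImageGeneration.from_pretrained(model_name_or_path)
+ tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path)
+
+ text_prompt = "A photo of an astronaut riding a horse on mars."
+ tokenized_inputs = tokenizer(text_prompt,
+ return_tensors="pd",
+ padding="max_length",
+ max_length=tokenizer.model_max_length)
+ images = model.generate(**tokenized_inputs, mode="text2image", seed=42)
+ images[0].save("astronaut.png")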
+
+ """
+ assert input_ids is not None, "text2image/image2image/inpaint, please specify `input_ids`"
+ if mode == "text2image":
+ return self.stable_diffusion_text2image(
+ input_ids=input_ids,
+ seed=seed,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ latents=latents,
+ fp16=fp16)
+ elif mode == "image2image":
+ assert init_image is not None, "image2image mode, please specify `init_image`"
+ # preprocess image
+ init_image = self.preprocess_image(init_image)
+ return self.stable_diffusion_image2image(
+ input_ids=input_ids,
+ init_image=init_image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ seed=seed,
+ fp16=fp16)
+
+ elif mode == "inpaint":
+ assert init_image is not None, "inpaint mode, please specify `init_image`"
+ assert mask_image is not None, "inpaint mode, please specify `mask_image`"
+ # preprocess image
+ init_image = self.preprocess_image(init_image)
+ # preprocess mask
+ mask_image = self.preprocess_mask(mask_image)
+
+ return self.stable_diffusion_inpainting(
+ input_ids=input_ids,
+ init_image=init_image,
+ mask_image=mask_image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ seed=seed,
+ fp16=fp16)
+ else:
+ raise ValueError(
+ 'Mode must be in ["text2image", "image2image", "inpaint"]')
+
+ def disco_diffusion_generate(
+ self,
+ input_ids,
+ attention_mask=None,
+ position_ids=None,
+ init_image=None,
+ output_dir='disco_diffusion_clip_vitb32_out/',
+ width_height=[1280, 768],
+ skip_steps=0,
+ steps=250,
+ cut_ic_pow=1,
+ init_scale=1000,
+ clip_guidance_scale=5000,
+ tv_scale=0,
+ range_scale=0,
+ sat_scale=0,
+ cutn_batches=4,
+ perlin_init=False,
+ perlin_mode='mixed',
+ seed=None,
+ eta=0.8,
+ clamp_grad=True,
+ clamp_max=0.05,
+ cut_overview='[12]*400+[4]*600',
+ cut_innercut='[4]*400+[12]*600',
+ cut_icgray_p='[0.2]*400+[0]*600',
+ save_rate=10,
+ n_batches=1,
+ batch_name="",
+ use_secondary_model=True,
+ randomize_class=True,
+ clip_denoised=False,
+ ):
+ r"""
+ The CLIPForImageGeneration disco_diffusion_generate method.
+
+ Args:
+ input_ids (Tensor):
+ See :class:`CLIPModel`.
+ attention_mask (Tensor, optional):
+ See :class:`CLIPModel`.
+ position_ids (Tensor, optional):
+ See :class:`CLIPModel`.
+ init_image (Path, optional):
+ Recall that in the image sequence above, the first image shown is just noise. If an init_image is provided,
+ diffusion will replace the noise with the init_image as its starting state. To use an init_image, pass the
+ full image path here. If using an init_image, you may need to increase skip_steps to ~50% of total steps to
+ retain the character of the init. See skip_steps below for further discussion.
+ Default to `None`.
+ output_dir (Path, optional):
+ Output directory.
+ Default to `disco_diffusion_clip_vitb32_out`.
+ width_height (List[int, int], optional):
+ Desired final image size, in pixels. You can have a square, wide, or tall image, but each edge
+ length should be set to a multiple of 64px, and a minimum of 512px on the default CLIP model setting.
+ If you forget to use multiples of 64px in your dimensions, DD will adjust the dimensions of your
+ image to make it so.
+ Default to `[1280, 768]`.
+ skip_steps (int, optional):
+ Noise scheduling (denoise strength) starts very high and progressively gets lower and lower as diffusion
+ steps progress. The noise levels in the first few steps are very high, so images change dramatically in early
+ steps. As DD moves along the curve, noise levels (and thus the amount an image changes per step) declines,
+ and image coherence from one step to the next increases. The first few steps of denoising are often so dramatic
+ that some steps (maybe 10-15% of total) can be skipped without affecting the final image. You can experiment
+ with this as a way to cut render times. If you skip too many steps, however, the remaining noise may not be
+ high enough to generate new content, and thus may not have time left to finish an image satisfactorily. Also,
+ depending on your other settings, you may need to skip steps to prevent CLIP from overshooting your goal,
+ resulting in blown out colors (hyper saturated, solid white, or solid black regions) or otherwise poor image
+ quality. Consider that the denoising process is at its strongest in the early steps, so skipping steps can
+ sometimes mitigate other problems. Lastly, if using an init_image, you will need to skip ~50% of the diffusion
+ steps to retain the shapes in the original init image. However, if you're using an init_image, you can also
+ adjust skip_steps up or down for creative reasons. With low skip_steps you can get a result "inspired by"
+ the init_image which will retain the colors and rough layout and shapes but look quite different. With high
+ skip_steps you can preserve most of the init_image contents and just do fine tuning of the texture.
+ Default to `0`.
+ steps (int, optional):
+ When creating an image, the denoising curve is subdivided into steps for processing. Each step (or iteration)
+ involves the AI looking at subsets of the image called 'cuts' and calculating the 'direction' the image
+ should be guided to be more like the prompt. Then it adjusts the image with the help of the diffusion denoiser,
+ and moves to the next step. Increasing steps will provide more opportunities for the AI to adjust the image,
+ and each adjustment will be smaller, and thus will yield a more precise, detailed image. Increasing steps
+ comes at the expense of longer render times. Also, while increasing steps should generally increase image
+ quality, there is a diminishing return on additional steps beyond 250 - 500 steps. However, some intricate
+ images can take 1000, 2000, or more steps. It is really up to the user. Just know that the render time is
+ directly related to the number of steps, and many other parameters have a major impact on image quality
+ without costing additional time.
+ Default to `250`.
+ cut_ic_pow (int, optional):
+ This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and
+ therefore the cuts themselves will be smaller and provide finer details. If you have too many or too-small
+ inner cuts, you may lose overall image coherency and/or it may cause an undesirable 'mosaic' effect.
+ Low cut_ic_pow values will allow the inner cuts to be larger, helping image coherency while still helping
+ with some details.
+ Default to `1`.
+ init_scale (int, optional):
+ This controls how strongly CLIP will try to match the init_image provided. This is balanced against the
+ clip_guidance_scale (CGS) above. Too much init scale, and the image won't change much during diffusion.
+ Too much CGS and the init image will be lost.
+ Default to `1000`.
+ clip_guidance_scale (int, optional):
+ CGS is one of the most important parameters you will use. It tells DD how strongly you want CLIP to move
+ toward your prompt each timestep. Higher is generally better, but if CGS is too strong it will overshoot
+ the goal and distort the image. So a happy medium is needed, and it takes experience to learn how to adjust
+ CGS. Note that this parameter generally scales with image dimensions. In other words, if you increase your
+ total dimensions by 50% (e.g. a change from 512 x 512 to 512 x 768), then to maintain the same effect on the
+ image, you'd want to increase clip_guidance_scale from 5000 to 7500. Of the basic settings, clip_guidance_scale,
+ steps and skip_steps are the most important contributors to image quality, so learn them well.
+ Default to `5000`.
+ tv_scale (int, optional):
+ Total variance denoising. Optional, set to zero to turn off. Controls smoothness of final output. If used,
+ tv_scale will try to smooth out your final image to reduce overall noise. If your image is too 'crunchy',
+ increase tv_scale. TV denoising is good at preserving edges while smoothing away noise in flat regions.
+ See https://en.wikipedia.org/wiki/Total_variation_denoising
+ Default to `0`.
+ range_scale (int, optional):
+ Optional, set to zero to turn off. Used for adjustment of color contrast. Lower range_scale will increase
+ contrast. Very low numbers create a reduced color palette, resulting in more vibrant or poster-like images.
+ Higher range_scale will reduce contrast, for more muted images.
+ Default to `0`.
+ sat_scale (int, optional):
+ Saturation scale. Optional, set to zero to turn off. If used, sat_scale will help mitigate oversaturation.
+ If your image is too saturated, increase sat_scale to reduce the saturation.
+ Default to `0`.
+ cutn_batches (int, optional):
+ Each iteration, the AI cuts the image into smaller pieces known as cuts, and compares each cut to the prompt
+ to decide how to guide the next diffusion step. More cuts can generally lead to better images, since DD has
+ more chances to fine-tune the image precision in each timestep. Additional cuts are memory intensive, however,
+ and if DD tries to evaluate too many cuts at once, it can run out of memory. You can use cutn_batches to increase
+ cuts per timestep without increasing memory usage. At the default settings, DD is scheduled to do 16 cuts per
+ timestep. If cutn_batches is set to 1, there will indeed only be 16 cuts total per timestep. However, if
+ cutn_batches is increased to 4, DD will do 64 cuts total in each timestep, divided into 4 sequential batches
+ of 16 cuts each. Because the cuts are being evaluated only 16 at a time, DD uses the memory required for only 16 cuts,
+ but gives you the quality benefit of 64 cuts. The tradeoff, of course, is that this will take ~4 times as long to
+ render each image. So, (scheduled cuts) x (cutn_batches) = (total cuts per timestep). Increasing cutn_batches will
+ increase render times, however, as the work is being done sequentially. DD's default cut schedule is a good place
+ to start, but the cut schedule can be adjusted in the Cutn Scheduling section, explained below.
+ Default to `4`.
+ perlin_init (bool, optional):
+ Normally, DD will use an image filled with random noise as a starting point for the diffusion curve.
+ If perlin_init is selected, DD will instead use a Perlin noise model as an initial state. Perlin has very
+ interesting characteristics, distinct from random noise, so it's worth experimenting with this for your projects.
+ Beyond perlin, you can, of course, generate your own noise images (such as with GIMP, etc) and use them as an
+ init_image (without skipping steps). Choosing perlin_init does not affect the actual diffusion process, just the
+ starting point for the diffusion. Please note that selecting a perlin_init will replace and override any init_image
+ you may have specified. Further, because the 2D, 3D and video animation systems all rely on the init_image system,
+ if you enable Perlin while using animation modes, the perlin_init will jump in front of any previous image or video
+ input, and DD will NOT give you the expected sequence of coherent images. All of that said, using Perlin and
+ animation modes together do make a very colorful rainbow effect, which can be used creatively.
+ Default to `False`.
+ perlin_mode (str, optional):
+ Sets the type of Perlin noise: colored, gray, or a mix of both, giving you additional options for noise types. Experiment
+ to see what these do in your projects.
+ Default to `mixed`.
+ seed (int, optional):
+ Deep in the diffusion code, there is a random number seed which is used as the basis for determining the initial
+ state of the diffusion. By default, this is random, but you can also specify your own seed. This is useful if you like a
+ particular result and would like to run more iterations that will be similar. After each run, the actual seed value used will be
+ reported in the parameters report, and can be reused if desired by entering seed # here. If a specific numerical seed is used
+ repeatedly, the resulting images will be quite similar but not identical.
+ Default to `None`.
+ eta (float, optional):
+ Eta (greek letter η) is a diffusion model variable that mixes in a random amount of scaled noise into each timestep.
+ 0 is no noise, 1.0 is more noise. As with most DD parameters, you can go below zero for eta, but it may give you
+ unpredictable results. The steps parameter has a close relationship with the eta parameter. If you set eta to 0,
+ then you can get decent output with only 50-75 steps. Setting eta to 1.0 favors higher step counts, ideally around
+ 250 and up. eta has a subtle, unpredictable effect on the image, so you'll need to experiment to see how this affects your projects.
+ Default to `0.8`.
+ clamp_grad (bool, optional):
+ As I understand it, clamp_grad is an internal limiter that stops DD from producing extreme results. Try your images with and without
+ clamp_grad. If the image changes drastically with clamp_grad turned off, it probably means your clip_guidance_scale is too high and
+ should be reduced.
+ Default to `True`.
+ clamp_max (float, optional):
+ Sets the value of the clamp_grad limitation. Default is 0.05, providing for smoother, more muted coloration in images, but setting
+ higher values (0.15-0.3) can provide interesting contrast and vibrancy.
+ Default to `0.05`.
+ cut_overview (str, optional):
+ The schedule of overview cuts.
+ Default to `'[12]*400+[4]*600'`.
+ cut_innercut (str, optional):
+ The schedule of inner cuts.
+ Default to `'[4]*400+[12]*600'`.
+ cut_icgray_p (str, optional):
+ The schedule for the fraction of inner cuts that are converted to grayscale before being scored by CLIP.
+ Grayscale cuts emphasize structure and contrast over color, which can help overall image coherency.
+ Default to `'[0.2]*400+[0]*600'`.
+ save_rate (int, optional):
+ During a diffusion run, you can monitor the progress of each image being created with this variable. If save_rate
+ is set to 50, DD will save the in-progress image every 50 timesteps. Setting this to a lower value, like 5 or 10,
+ is a good way to get an early peek at where your image is heading. If you don't like the progression, just interrupt
+ execution, change some settings, and re-run. If you are planning a long, unmonitored batch, it's better to set
+ save_rate equal to steps, because saving interim images does slow generation down slightly.
+ Default to `10`.
+ n_batches (int, optional):
+ This variable sets the number of still images you want DD to create. If you are using an animation mode (see below for details)
+ DD will ignore n_batches and create a single set of animated frames based on the animation settings.
+ Default to `1`.
+ batch_name (str, optional):
+ The name of the batch. The batch id will be named as "progress-[batch_name]-seed-[range(n_batches)]-[save_rate]".
+ To avoid your artwork being overwritten by other users, please use a unique name.
+ Default to `''`.
+ use_secondary_model (bool, optional):
+ Whether or not to use the secondary model.
+ Default to `True`.
+ randomize_class (bool, optional):
+ Whether to randomize the class labels at each timestep (only relevant for class-conditional models).
+ Default to `True`.
+ clip_denoised (bool, optional):
+ Whether to clip the denoised prediction into [-1, 1].
+ Default to `False`.
+
+ Returns:
+ List[PIL.Image]: Returns a list of `n_batches` final images.
+ Each element is a `PIL.Image`.
+
+ Example:
+ .. code-block::
+
+ from paddlenlp.transformers import CLIPForImageGeneration, CLIPTokenizer
+
+ # Initialize the model and tokenizer
+ model_name_or_path = 'openai/disco-diffusion-clip-vit-base-patch32'
+ model = CLIPForImageGeneration.from_pretrained(model_name_or_path)
+ tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path)
+ model.eval()
+
+ # Prepare the text_prompt.
+ text_prompt = "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation."
+ # Style and artist can be specified
+ style = None
+ artist = None
+ text_prompt = model.preprocess_text_prompt(text_prompt, style=style, artist=artist)
+
+ # CLIP's pad_token_id is 0 and padding to max length 77 (tokenizer.model_max_length).
+ raw_pad_token_id = tokenizer.pad_token_id
+ tokenizer.pad_token_id = 0
+ tokenized_inputs = tokenizer(text_prompt, return_tensors="pd", padding="max_length", max_length=tokenizer.model_max_length)
+ images = model.generate(**tokenized_inputs)
+
+ # return List[PIL.Image]
+ images[0].save("figure.png")
+
+ """
+ self.diffusion = create_gaussian_diffusion(
+ steps=steps,
+ learn_sigma=True,
+ sigma_small=False,
+ noise_schedule="linear",
+ predict_xstart=False,
+ rescale_timesteps=True,
+ )
+ # compute the target text embeddings with get_text_features
+ target_text_embeds = self.get_text_features(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids)
+
+ images_list = super().disco_diffusion_generate(
+ target_text_embeds=target_text_embeds,
+ clamp_grad=clamp_grad,
+ clamp_max=clamp_max,
+ clip_denoised=clip_denoised,
+ clip_guidance_scale=clip_guidance_scale,
+ cut_ic_pow=cut_ic_pow,
+ cut_icgray_p=cut_icgray_p,
+ cut_innercut=cut_innercut,
+ cut_overview=cut_overview,
+ cutn_batches=cutn_batches,
+ save_rate=save_rate,
+ eta=eta,
+ init_image=init_image,
+ init_scale=init_scale,
+ n_batches=n_batches,
+ output_dir=output_dir,
+ perlin_init=perlin_init,
+ perlin_mode=perlin_mode,
+ randomize_class=randomize_class,
+ range_scale=range_scale,
+ sat_scale=sat_scale,
+ seed=seed,
+ skip_steps=skip_steps,
+ tv_scale=tv_scale,
+ use_secondary_model=use_secondary_model,
+ width_height=width_height,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ batch_name=batch_name)
+
+ return images_list
+
+ def preprocess_text_prompt(self, text_prompt, style=None, artist=None):
+ text_prompt = text_prompt.rstrip(',.,。')
+ if style is not None:
+ text_prompt += ",{}".format(style)
+ if artist is not None:
+ text_prompt += ",{},trending on artstation".format(artist)
+ return text_prompt
+
+ def __getattr__(self, name):
+ try:
+ return super().__getattr__(name)
+ except AttributeError as e:
+ try:
+ return getattr(getattr(self, self.base_model_prefix), name)
+ except AttributeError:
+ try:
+ return getattr(self, self.base_model_prefix).config[name]
+ except KeyError:
+ raise e
diff --git a/paddlenlp/transformers/clip/procesing.py b/paddlenlp/transformers/clip/procesing.py
new file mode 100644
index 000000000000..2679b642b801
--- /dev/null
+++ b/paddlenlp/transformers/clip/procesing.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for CLIP
+"""
+
+from ..tokenizer_utils_base import BatchEncoding
+from .tokenizer import CLIPTokenizer
+from .feature_extraction import CLIPFeatureExtractor
+
+__all__ = ["CLIPProcessor"]
+
+
+class CLIPProcessor(object):
+ r"""
+ Constructs a CLIP processor which wraps a CLIP feature extractor and a CLIP tokenizer into a single processor.
+ [`CLIPProcessor`] offers all the functionalities of [`CLIPFeatureExtractor`] and [`CLIPTokenizer`]. See the
+ [`CLIPProcessor.__call__`] and [`CLIPProcessor.decode`] for more information.
+ Args:
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ The feature extractor is a required input.
+ tokenizer ([`CLIPTokenizer`]):
+ The tokenizer is a required input.
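+
+ Example:
+ A minimal usage sketch. The image file name is an illustrative assumption:
+
+ .. code-block::
+
+ from PIL import Image
+ from paddlenlp.transformers import CLIPProcessor
+
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ image = Image.open("cat.png") # any local RGB image
+ inputs = processor(text="a photo of a cat", images=image, return_tensors="pd")
+ # inputs contains "input_ids" (from the tokenizer) and "pixel_values" (from the feature extractor)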
+ """
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.feature_extractor = feature_extractor
+
+ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+ """
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to CLIPTokenizer's [`CLIPTokenizer.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ CLIPFeatureExtractor's [`CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
+ docstring of the above two methods for more information.
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
+ tensor. In case of a NumPy array/Paddle tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'pd'`: Return Paddle `paddle.Tensor` objects.
+ Returns:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+
+ if text is None and images is None:
+ raise ValueError(
+ "You have to specify either text or images. Both cannot be None."
+ )
+
+ if text is not None:
+ encoding = self.tokenizer(text,
+ return_tensors=return_tensors,
+ **kwargs)
+
+ if images is not None:
+ image_features = self.feature_extractor(
+ images, return_tensors=return_tensors, **kwargs)
+
+ if text is not None and images is not None:
+ encoding["pixel_values"] = image_features.pixel_values
+ return encoding
+ elif text is not None:
+ return encoding
+ else:
+ return BatchEncoding(data=dict(**image_features),
+ tensor_type=return_tensors)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to CLIPTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to CLIPTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ # TODO junnyu find a better way from_pretrained and save_pretrained
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path,
+ *args, **kwargs)
+ feature_extractor = CLIPFeatureExtractor()
+ return cls(feature_extractor, tokenizer)
+
+ def save_pretrained(self, save_directory, filename_prefix=None, **kwargs):
+ return self.tokenizer.save_pretrained(save_directory, filename_prefix,
+ **kwargs)
diff --git a/paddlenlp/transformers/clip/tokenizer.py b/paddlenlp/transformers/clip/tokenizer.py
new file mode 100644
index 000000000000..4e5d384dc9d3
--- /dev/null
+++ b/paddlenlp/transformers/clip/tokenizer.py
@@ -0,0 +1,419 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.utils import try_import
+import os
+import shutil
+from .. import PretrainedTokenizer, AddedToken, BasicTokenizer
+from ...utils.log import logger
+from functools import lru_cache
+import json
+
+__all__ = ['CLIPTokenizer']
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
+ characters the bpe code barfs on.
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (list(range(ord("!"),
+ ord("~") + 1)) + list(range(ord("¡"),
+ ord("¬") + 1)) +
+ list(range(ord("®"),
+ ord("ÿ") + 1)))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def whitespace_clean(text, re):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class CLIPTokenizer(PretrainedTokenizer):
+ r"""
+ Construct a CLIP tokenizer based on byte-level Byte-Pair-Encoding.
+
+ This tokenizer inherits from :class:`~paddlenlp.transformers.tokenizer_utils.PretrainedTokenizer`.
+ For more information regarding those methods, please refer to this superclass.
+
+ Args:
+ vocab_file (str):
+ Path to the vocabulary file.
+ The vocab file contains a mapping from vocabulary strings to indices.
+ merges_file (str):
+ Path to the merge file.
+ The merge file is used to split the input sentence into "subword" units.
+ The vocab file is then used to encode those units as indices.
+ errors (str):
+ Paradigm to follow when decoding bytes to UTF-8.
+ Defaults to `'replace'`.
+ max_len (int, optional):
+ The maximum value of the input sequence length.
+ Defaults to `77`.
+ bos_token (str, optional):
+ The beginning of sequence token that was used during pretraining. Can be
+ used as a sequence classifier token.
+ Defaults to `"<|startoftext|>"`.
+ eos_token (str, optional):
+ A special token representing the end of a sequence that was used during pretraining.
+ Defaults to `"<|endoftext|>"`.
+ unk_token (str, optional):
+ A special token representing the *unknown (out-of-vocabulary)* token.
+ An unknown token is set to `unk_token` in order to be converted to an ID.
+ Defaults to `"<|endoftext|>"`.
+ pad_token (str, optional):
+ A special token used to make arrays of tokens the same size for batching purposes.
+ Defaults to `"<|endoftext|>"`.
+
+ Examples:
+ .. code-block::
+
+ from paddlenlp.transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
+ print(tokenizer('He was a puppeteer'))
+
+ '''
+ {'input_ids': [49406, 797, 739, 320, 7116, 38820, 528, 49407]}
+ '''
+
+ """
+ # merges and vocab same as GPT2
+ resource_files_names = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt"
+ }
+ pretrained_resource_files_map = {
+ "vocab_file": {
+ "openai/clip-vit-base-patch32":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-base-patch32/vocab.json",
+ "openai/clip-rn50":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn50/vocab.json",
+ "openai/clip-rn101":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn101/vocab.json",
+ "openai/clip-vit-large-patch14":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-large-patch14/vocab.json",
+ },
+ "merges_file": {
+ "openai/clip-vit-base-patch32":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-base-patch32/merges.txt",
+ "openai/clip-rn50":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn50/merges.txt",
+ "openai/clip-rn101":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-rn101/merges.txt",
+ "openai/clip-vit-large-patch14":
+ "http://bj.bcebos.com/paddlenlp/models/community/openai/clip-vit-large-patch14/merges.txt",
+ }
+ }
+ pretrained_init_configuration = {
+ "openai/clip-vit-base-patch32": {
+ "max_len": 77
+ },
+ "openai/clip-rn50": {
+ "max_len": 77
+ },
+ "openai/clip-rn101": {
+ "max_len": 77
+ },
+ "openai/clip-vit-large-patch14": {
+ "max_len": 77
+ },
+ }
+
+ def __init__(self,
+ vocab_file,
+ merges_file,
+ errors='replace',
+ max_len=77,
+ bos_token="<|startoftext|>",
+ eos_token="<|endoftext|>",
+ unk_token="<|endoftext|>",
+ pad_token="<|endoftext|>",
+ **kwargs):
+
+ bos_token = AddedToken(bos_token,
+ lstrip=False, rstrip=False) if isinstance(
+ bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token,
+ lstrip=False, rstrip=False) if isinstance(
+ eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token,
+ lstrip=False, rstrip=False) if isinstance(
+ unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token,
+ lstrip=False, rstrip=False) if isinstance(
+ pad_token, str) else pad_token
+
+ self._build_special_tokens_map_extended(bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token)
+
+ try:
+ import ftfy
+ self.fix_text = ftfy.fix_text
+ except ImportError:
+ logger.warning(
+ "ftfy is not installed, using BERT BasicTokenizer instead of ftfy."
+ )
+ self.nlp = BasicTokenizer(do_lower_case=True)
+ self.fix_text = None
+ self.re = try_import("regex")
+
+ self._vocab_file = vocab_file
+ self._merges_file = merges_file
+ self.max_len = max_len if max_len is not None else int(1e12)
+
+ with open(vocab_file, 'r', encoding='utf-8') as f:
+ self.encoder = json.load(f)
+
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().strip().split("\n")[1:49152 -
+ 256 - 2 + 1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {
+ "<|startoftext|>": "<|startoftext|>",
+ "<|endoftext|>": "<|endoftext|>"
+ }
+
+ self.pat = self.re.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ self.re.IGNORECASE,
+ )
+
+ @property
+ def vocab_size(self):
+ """
+ Returns the size of vocabulary.
+
+ Returns:
+ int: The size of the base vocabulary (excluding added tokens).
+
+ """
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens. A CLIP sequence has the following format:
+
+ - single sequence: `<|startoftext|> X <|endoftext|>`
+
+ Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+ Args:
+ token_ids_0 (List[int]):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (List[int], optional):
+ Optional second list of IDs for sequence pairs. Defaults to None.
+
+ Returns:
+ List[int]: List of input_id with the appropriate special tokens.
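+
+ Example:
+ A quick sketch of the concatenation, given a CLIPTokenizer instance `tokenizer`
+ (49406 and 49407 are the start- and end-of-text ids):
+
+ .. code-block::
+
+ # ids for a sequence without special tokens
+ tokenizer.build_inputs_with_special_tokens([797, 739, 320])
+ # -> [49406, 797, 739, 320, 49407]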
+ """
+ _bos = [self.bos_token_id]
+ _eos = [self.eos_token_id]
+ if token_ids_1 is None:
+ return _bos + token_ids_0 + _eos
+ return _bos + token_ids_0 + _eos + _eos + token_ids_1 + _eos
+
+ def get_special_tokens_mask(self,
+ token_ids_0,
+ token_ids_1=None,
+ already_has_special_tokens=False):
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is
+ called when adding special tokens using the tokenizer ``encode`` methods.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0,
+ token_ids_1=token_ids_1,
+ already_has_special_tokens=True)
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + (
+ [0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(self,
+ token_ids_0,
+ token_ids_1=None):
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+ """
+ eos = [self.eos_token_id]
+ bos = [self.bos_token_id]
+
+ if token_ids_1 is None:
+ return len(bos + token_ids_0 + eos) * [0]
+ return len(bos + token_ids_0 + eos + eos + token_ids_1 + eos) * [0]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + (token[-1] + "</w>", )
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + "</w>"
+
+ while True:
+ bigram = min(
+ pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i +
+ 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ if self.fix_text is None:
+ text = " ".join(self.nlp.tokenize(text))
+ else:
+ text = whitespace_clean(self.fix_text(text), self.re).lower()
+
+ for token in self.re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token
+ for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ return self.decoder[index]
+
+ def convert_tokens_to_string(self, tokens):
+ """
+ Converts a sequence of tokens (string) in a single string.
+ """
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+ "utf-8", errors=self.errors).replace("</w>", " ").strip()
+ return text
+
+ def save_resources(self, save_directory):
+ """
+ Saves resource files (`vocab.json` and `merges.txt`) under `save_directory`.
+
+ Args:
+ save_directory (str): Directory to save files into.
+ """
+ for name, file_name in self.resource_files_names.items():
+ source_path = getattr(self, "_%s" % name)
+
+ save_path = os.path.join(save_directory, file_name)
+ if os.path.abspath(source_path) != os.path.abspath(save_path):
+ shutil.copyfile(source_path, save_path)
+
+ def __call__(
+ self,
+ text,
+ text_pair=None,
+ max_length=None,
+ stride=0,
+ is_split_into_words=False,
+ padding=False,
+ truncation=False,
+ return_position_ids=False,
+ return_token_type_ids=False, # don't return token_type_ids
+ return_attention_mask=False,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False,
+ return_dict=True,
+ return_offsets_mapping=False,
+ add_special_tokens=True,
+ pad_to_multiple_of=None,
+ return_tensors=None,
+ verbose: bool = True,
+ **kwargs):
+ return super().__call__(
+ text, text_pair, max_length, stride, is_split_into_words, padding,
+ truncation, return_position_ids, return_token_type_ids,
+ return_attention_mask, return_length, return_overflowing_tokens,
+ return_special_tokens_mask, return_dict, return_offsets_mapping,
+ add_special_tokens, pad_to_multiple_of, return_tensors, verbose,
+ **kwargs)
diff --git a/paddlenlp/transformers/feature_extraction_utils.py b/paddlenlp/transformers/feature_extraction_utils.py
new file mode 100644
index 000000000000..5e612ab3e19f
--- /dev/null
+++ b/paddlenlp/transformers/feature_extraction_utils.py
@@ -0,0 +1,121 @@
+# coding=utf-8
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from collections import UserDict
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+from .tokenizer_utils_base import TensorType
+
+
+class BatchFeature(UserDict):
+ r"""
+ Holds the feature extractor specific `__call__` methods.
+ This class is derived from a python dictionary and can be used as a dictionary.
+ Args:
+ data (`dict`):
+ Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
+ etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in Paddle/Numpy Tensors at
+ initialization.
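+
+ Example:
+ A minimal sketch converting list data into Paddle tensors (the dummy array stands in for real pixel values):
+
+ .. code-block::
+
+ import numpy as np
+ from paddlenlp.transformers.feature_extraction_utils import BatchFeature
+
+ features = BatchFeature(data={"pixel_values": [np.zeros((3, 224, 224), dtype="float32")]},
+ tensor_type="pd")
+ print(features["pixel_values"].shape) # [1, 3, 224, 224]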
+ """
+
+ def __init__(self,
+ data: Optional[Dict[str, Any]] = None,
+ tensor_type: Union[None, str, TensorType] = None):
+ super().__init__(data)
+ self.convert_to_tensors(tensor_type=tensor_type)
+
+ def __getitem__(self, item: str):
+ """
+ If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
+ etc.).
+ """
+ if isinstance(item, str):
+ return self.data[item]
+ else:
+ raise KeyError(
+ "Indexing with integers is not available when using Python based feature extractors"
+ )
+
+ def __getattr__(self, item: str):
+ try:
+ return self.data[item]
+ except KeyError:
+ raise AttributeError
+
+ def __getstate__(self):
+ return {"data": self.data}
+
+ def __setstate__(self, state):
+ if "data" in state:
+ self.data = state["data"]
+
+ def keys(self):
+ return self.data.keys()
+
+ def values(self):
+ return self.data.values()
+
+ def items(self):
+ return self.data.items()
+
+ def convert_to_tensors(self,
+ tensor_type: Optional[Union[str,
+ TensorType]] = None):
+ """
+ Convert the inner content to tensors.
+ Args:
+ tensor_type (`str` or [`TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`TensorType`]. If
+ `None`, no modification is done.
+ """
+ if tensor_type is None:
+ return self
+
+ # Convert to TensorType
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ # Get a function reference for the correct framework
+ if tensor_type == TensorType.PADDLE:
+ as_tensor = paddle.to_tensor
+ is_tensor = paddle.is_tensor
+ else:
+ as_tensor = np.asarray
+ is_tensor = lambda x: isinstance(x, np.ndarray)
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ try:
+ if not is_tensor(value):
+ tensor = as_tensor(value)
+
+ self[key] = tensor
+ except: # noqa E722
+ if key == "overflowing_tokens":
+ raise ValueError(
+ "Unable to create tensor returning overflowing tokens of different lengths. "
+ "Please see if a fast version of this tokenizer is available to have this feature available."
+ )
+ raise ValueError(
+ "Unable to create tensor, you should probably activate truncation and/or padding "
+ "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
+ )
+
+ return self
diff --git a/paddlenlp/transformers/guided_diffusion_utils/__init__.py b/paddlenlp/transformers/guided_diffusion_utils/__init__.py
new file mode 100644
index 000000000000..46c4d89f4cd9
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/__init__.py
@@ -0,0 +1,2 @@
+from .model_diffusion import create_gaussian_diffusion, create_unet_model, create_secondary_model
+from .utils import DiscoDiffusionMixin
\ No newline at end of file
diff --git a/paddlenlp/transformers/guided_diffusion_utils/gaussian_diffusion.py b/paddlenlp/transformers/guided_diffusion_utils/gaussian_diffusion.py
new file mode 100755
index 000000000000..7736c42a6cef
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/gaussian_diffusion.py
@@ -0,0 +1,771 @@
+"""
+Diffusion model implemented by Paddle.
+This code is rewritten based on the PyTorch version of Ho et al.'s diffusion models:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+"""
+import enum
+import math
+
+import numpy as np
+import paddle
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(beta_start,
+ beta_end,
+ num_diffusion_timesteps,
+ dtype=np.float64)
+ elif schedule_name == "cosine":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class GaussianDiffusion:
+ """
+ Utilities for sampling diffusion models.
+
+ Ported directly from here, and then adapted over time for further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
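+
+ Example (a minimal construction sketch using the helpers defined in this module):
+
+ betas = get_named_beta_schedule("linear", 1000)
+ diffusion = GaussianDiffusion(
+ betas=betas,
+ model_mean_type=ModelMeanType.EPSILON,
+ model_var_type=ModelVarType.LEARNED_RANGE,
+ rescale_timesteps=True,
+ )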
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ rescale_timesteps=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.rescale_timesteps = rescale_timesteps
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
+ 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
+ (1.0 - self.alphas_cumprod))
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:]))
+ self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) /
+ (1.0 - self.alphas_cumprod))
+ self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
+ np.sqrt(alphas) /
+ (1.0 - self.alphas_cumprod))
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
+ x_start)
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
+ x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
+ t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
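+
+ In closed form (matching the implementation below):
+ x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise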
+ """
+ if noise is None:
+ # noise = th.randn_like(x_start)
+ noise = paddle.randn(x_start.shape, x_start.dtype)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
+ x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
+ t, x_start.shape) * noise)
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
+ x_start +
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t,
+ x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape)
+ assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
+ posterior_log_variance_clipped.shape[0] == x_start.shape[0])
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == [B]
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE
+ ]:
+ assert model_output.shape == [B, C * 2, *x.shape[2:]]
+ model_output, model_var_values = paddle.split(model_output,
+ 2,
+ axis=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = paddle.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = paddle.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1],
+ self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t,
+ x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
+ model_mean = model_output
+ elif self.model_mean_type in [
+ ModelMeanType.START_X, ModelMeanType.EPSILON
+ ]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t)
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (model_mean.shape == model_log_variance.shape ==
+ pred_xstart.shape == x.shape)
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) *
+ x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
+ x_t.shape) * eps)
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
+ * xprev - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
+ x_t.shape) * x_t)
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) *
+ x_t - pred_xstart) / _extract_into_tensor(
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return paddle.cast((t), 'float32') * (1000.0 / self.num_timesteps)
+ return t
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
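+
+ Concretely (as implemented below): new_mean = mean + variance * grad(log p(y|x)).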
+ """
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+ new_mean = (paddle.cast((p_mean_var["mean"]), 'float32') +
+ p_mean_var["variance"] * paddle.cast((gradient), 'float32'))
+ return new_mean
+
+ def condition_mean_with_grad(self,
+ cond_fn,
+ p_mean_var,
+ x,
+ t,
+ model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
+ new_mean = (paddle.cast((p_mean_var["mean"]), 'float32') +
+ p_mean_var["variance"] * paddle.cast((gradient), 'float32'))
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+
+ See condition_mean() for details on cond_fn.
+
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
+ x, self._scale_timesteps(t), **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
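+
+    # A sketch of the Song et al. (2020) update implemented above (added for
+    # clarity, not part of the original patch):
+    #   eps_cond = eps - sqrt(1 - alpha_bar_t) * cond_fn(x, t)
+    # where cond_fn approximates grad_x log p(y | x); x_0 and the posterior
+    # mean are then re-derived from eps_cond.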
+
+ def condition_score_with_grad(self,
+ cond_fn,
+ p_mean_var,
+ x,
+ t,
+ model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+
+ See condition_mean() for details on cond_fn.
+
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, p_mean_var, **
+ model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ """
+ out_orig = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn,
+ out_orig,
+ x,
+ t,
+ model_kwargs=model_kwargs)
+ else:
+ out = out_orig
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
+ x.shape)
+ sigma = (eta * paddle.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
+ paddle.sqrt(1 - alpha_bar / alpha_bar_prev))
+ # Equation 12.
+ # noise = th.randn_like(x)
+ noise = paddle.randn(x.shape, x.dtype)
+ mean_pred = (out["pred_xstart"] * paddle.sqrt(alpha_bar_prev) +
+ paddle.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
+ nonzero_mask = (paddle.cast(
+ (t != 0), 'float32').reshape([-1, *([1] * (len(x.shape) - 1))])
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
+
+ def ddim_sample_with_grad(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ """
+ # with th.enable_grad():
+ # x = x.detach().requires_grad_()
+ x = x.detach()
+ # x.stop_gradient = False
+ out_orig = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score_with_grad(cond_fn,
+ out_orig,
+ x,
+ t,
+ model_kwargs=model_kwargs)
+ else:
+ out = out_orig
+
+ out["pred_xstart"] = out["pred_xstart"].detach()
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
+ x.shape)
+ sigma = (eta * paddle.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
+ paddle.sqrt(1 - alpha_bar / alpha_bar_prev))
+ # Equation 12.
+ # noise = th.randn_like(x)
+ noise = paddle.randn(x.shape, x.dtype)
+ mean_pred = (out["pred_xstart"] * paddle.sqrt(alpha_bar_prev) +
+ paddle.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
+ nonzero_mask = (paddle.cast(
+ (t != 0), 'float32').reshape([-1, *([1] * (len(x.shape) - 1))])
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {
+ "sample": sample,
+ "pred_xstart": out_orig["pred_xstart"].detach()
+ }
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ progress=False,
+ eta=0.0,
+ skip_timesteps=0,
+ init_image=None,
+ randomize_class=False,
+ cond_fn_with_grad=False,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ progress=progress,
+ eta=eta,
+ skip_timesteps=skip_timesteps,
+ init_image=init_image,
+ randomize_class=randomize_class,
+ cond_fn_with_grad=cond_fn_with_grad,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ progress=False,
+ eta=0.0,
+ skip_timesteps=0,
+ init_image=None,
+ randomize_class=False,
+ cond_fn_with_grad=False,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ """
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = paddle.randn(shape)
+
+ if skip_timesteps and init_image is None:
+ init_image = paddle.zeros_like(img)
+
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
+
+ if init_image is not None:
+ my_t = paddle.ones([shape[0]], dtype='int64') * indices[0]
+ img = self.q_sample(init_image, my_t, img)
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = paddle.to_tensor([i] * shape[0])
+ if randomize_class and 'y' in model_kwargs:
+ model_kwargs['y'] = paddle.randint(
+ low=0,
+ high=model.num_classes,
+ shape=model_kwargs['y'].shape,
+ )
+ sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample
+ out = sample_fn(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+    For example, if there are 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+            raise ValueError(
+                f"cannot create exactly {desired_count} steps with an integer stride"
+            )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}")
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
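+# Illustrative usage sketch (not part of the original patch):
+#   space_timesteps(1000, "ddim250")   # 250 evenly strided steps (stride 4)
+#   space_timesteps(300, [10, 15, 20]) # 45 steps drawn from three equal
+#                                      # 100-step sections of the schedule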
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args,
+ **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args,
+ **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
+ self.original_num_steps)
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+
+ def __init__(self, model, timestep_map, rescale_timesteps,
+ original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = paddle.to_tensor(self.timestep_map,
+ place=ts.place,
+ dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = paddle.cast(new_ts,
+ 'float32') * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = paddle.to_tensor(arr, place=timesteps.place)[timesteps]
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
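+
+
+# Illustrative usage sketch (not part of the original patch):
+#   arr = np.linspace(0.1, 1.0, 10)              # one value per timestep
+#   t = paddle.to_tensor([0, 9])                 # a batch of two timesteps
+#   _extract_into_tensor(arr, t, [2, 3, 4, 4])   # -> tensor of shape [2, 3, 4, 4]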
diff --git a/paddlenlp/transformers/guided_diffusion_utils/losses.py b/paddlenlp/transformers/guided_diffusion_utils/losses.py
new file mode 100755
index 000000000000..a43a53f2fca9
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/losses.py
@@ -0,0 +1,26 @@
+"""
+Helpers for various likelihood-based losses, implemented in Paddle. These are
+ported from the original Ho et al. diffusion models codebase:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
+"""
+import paddle
+import paddle.nn.functional as F
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, axis=-1)
+ y = F.normalize(y, axis=-1)
+ return (x - y).norm(axis=-1).divide(
+ paddle.to_tensor(2.0)).asin().pow(2).multiply(paddle.to_tensor(2.0))
+
+
+def tv_loss(input):
+ """L2 total variation loss, as in Mahendran et al."""
+ input = F.pad(input, (0, 1, 0, 1), 'replicate')
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+ return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+
+def range_loss(input):
+ return (input - input.clip(-1, 1)).pow(2).mean([1, 2, 3])
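+
+
+# Illustrative usage sketch (not part of the original patch): each loss
+# reduces a batch to one value per sample.
+#   img = paddle.rand([2, 3, 64, 64])
+#   a, b = paddle.rand([2, 512]), paddle.rand([2, 512])
+#   spherical_dist_loss(a, b)  # -> shape [2]
+#   tv_loss(img)               # -> shape [2]
+#   range_loss(img)            # -> shape [2]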
diff --git a/paddlenlp/transformers/guided_diffusion_utils/make_cutouts.py b/paddlenlp/transformers/guided_diffusion_utils/make_cutouts.py
new file mode 100755
index 000000000000..9d9831993923
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/make_cutouts.py
@@ -0,0 +1,96 @@
+'''
+This code is rewritten in Paddle based on jina-ai/discoart.
+https://github.com/jina-ai/discoart/blob/main/discoart/nn/make_cutouts.py
+'''
+import paddle
+import paddle.nn as nn
+from paddle.nn import functional as F
+from .resize_right import resize
+
+from . import transforms as T
+
+skip_augs = False
+padargs = {}
+
+
+class MakeCutoutsDango(nn.Layer):
+
+ def __init__(self,
+ cut_size,
+ Overview=4,
+ InnerCrop=0,
+ IC_Size_Pow=0.5,
+ IC_Grey_P=0.2):
+ super().__init__()
+ self.cut_size = cut_size
+ self.Overview = Overview
+ self.InnerCrop = InnerCrop
+ self.IC_Size_Pow = IC_Size_Pow
+ self.IC_Grey_P = IC_Grey_P
+ self.augs = nn.Sequential(*[
+ T.RandomHorizontalFlip(prob=0.5),
+ T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
+ T.RandomAffine(
+ degrees=10,
+ translate=(0.05, 0.05),
+ interpolation=T.InterpolationMode.BILINEAR,
+ ),
+ T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
+ T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1,
+ hue=0.1),
+ ])
+
+ def forward(self, input):
+ cutouts = []
+ gray = T.Grayscale(3)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ output_shape = [1, 3, self.cut_size, self.cut_size]
+ pad_input = F.pad(
+ input,
+ (
+ (sideY - max_size) // 2,
+ (sideY - max_size) // 2,
+ (sideX - max_size) // 2,
+ (sideX - max_size) // 2,
+ ),
+ **padargs,
+ )
+ cutout = resize(pad_input, out_shape=output_shape)
+
+ if self.Overview > 0:
+ if self.Overview <= 4:
+ if self.Overview >= 1:
+ cutouts.append(cutout)
+ if self.Overview >= 2:
+ cutouts.append(gray(cutout))
+ if self.Overview >= 3:
+ cutouts.append(cutout[:, :, :, ::-1])
+ if self.Overview == 4:
+ cutouts.append(gray(cutout[:, :, :, ::-1]))
+ else:
+ cutout = resize(pad_input, out_shape=output_shape)
+ for _ in range(self.Overview):
+ cutouts.append(cutout)
+
+ if self.InnerCrop > 0:
+ for i in range(self.InnerCrop):
+ size = int(
+ paddle.rand([1])**self.IC_Size_Pow * (max_size - min_size) +
+ min_size)
+ offsetx = paddle.randint(0, sideX - size + 1)
+ offsety = paddle.randint(0, sideY - size + 1)
+ cutout = input[:, :, offsety:offsety + size,
+ offsetx:offsetx + size]
+ if i <= int(self.IC_Grey_P * self.InnerCrop):
+ cutout = gray(cutout)
+ cutout = resize(cutout, out_shape=output_shape)
+ cutouts.append(cutout)
+
+ cutouts = paddle.concat(cutouts)
+ if skip_augs is not True:
+ cutouts = self.augs(cutouts)
+ return cutouts
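+
+
+# Illustrative usage sketch (not part of the original patch):
+#   cutter = MakeCutoutsDango(cut_size=224, Overview=4, InnerCrop=2)
+#   cuts = cutter(paddle.rand([1, 3, 512, 512]))   # -> [6, 3, 224, 224]
+#   # four "overview" crops of the full frame plus two random inner crops,
+#   # each resized to cut_size and lightly augmented.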
diff --git a/paddlenlp/transformers/guided_diffusion_utils/model_diffusion.py b/paddlenlp/transformers/guided_diffusion_utils/model_diffusion.py
new file mode 100644
index 000000000000..c9ae17bcb9f7
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/model_diffusion.py
@@ -0,0 +1,96 @@
+'''
+This code is based on
+https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py
+'''
+from .gaussian_diffusion import get_named_beta_schedule, SpacedDiffusion, space_timesteps, ModelVarType, ModelMeanType
+from .unet import UNetModel
+from .sec_diff import SecondaryDiffusionImageNet2
+
+NUM_CLASSES = 1000
+
+
+def create_unet_model(
+ image_size=512,
+ num_channels=256,
+ num_res_blocks=2,
+ channel_mult="",
+ learn_sigma=True,
+ class_cond=False,
+ attention_resolutions='32, 16, 8',
+ num_heads=4,
+ num_head_channels=64,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=True,
+ dropout=0.0,
+ resblock_updown=True,
+ use_new_attention_order=False,
+):
+ if channel_mult == "":
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+ else:
+ channel_mult = tuple(
+ int(ch_mult) for ch_mult in channel_mult.split(","))
+
+ attention_ds = []
+ for res in attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return UNetModel(
+ image_size=image_size,
+ in_channels=3,
+ model_channels=num_channels,
+ out_channels=(3 if not learn_sigma else 6),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(NUM_CLASSES if class_cond else None),
+ use_fp16=False,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_new_attention_order=use_new_attention_order,
+ )
+
+
+def create_secondary_model():
+ model = SecondaryDiffusionImageNet2()
+ return model
+
+
+def create_gaussian_diffusion(
+ steps=250,
+ learn_sigma=True,
+ sigma_small=False,
+ noise_schedule="linear",
+ predict_xstart=False,
+ rescale_timesteps=True,
+):
+    # process the requested number of steps
+ timestep_respacing = f'ddim{steps}'
+ steps = (1000 // steps) * steps if steps < 1000 else steps
+
+ betas = get_named_beta_schedule(noise_schedule, steps)
+ if not timestep_respacing:
+ timestep_respacing = [steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(ModelMeanType.EPSILON
+ if not predict_xstart else ModelMeanType.START_X),
+ model_var_type=((ModelVarType.FIXED_LARGE
+ if not sigma_small else ModelVarType.FIXED_SMALL)
+ if not learn_sigma else ModelVarType.LEARNED_RANGE),
+ rescale_timesteps=rescale_timesteps,
+ )
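+
+
+# Illustrative usage sketch (not part of the original patch):
+#   unet = create_unet_model(image_size=512)
+#   diffusion = create_gaussian_diffusion(steps=250)  # 250 DDIM steps spaced
+#                                                     # over a 1000-step
+#                                                     # linear schedule
+#   secondary = create_secondary_model()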
diff --git a/paddlenlp/transformers/guided_diffusion_utils/perlin_noises.py b/paddlenlp/transformers/guided_diffusion_utils/perlin_noises.py
new file mode 100755
index 000000000000..fe1688974de8
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/perlin_noises.py
@@ -0,0 +1,85 @@
+'''
+Perlin noise implementation in Paddle.
+This code is rewritten based on:
+https://github.com/jina-ai/discoart/blob/main/discoart/nn/perlin_noises.py
+'''
+import numpy as np
+import paddle
+import paddle.vision.transforms as PF
+from PIL import Image, ImageOps
+
+
+def interp(t):
+ return 3 * t**2 - 2 * t**3
+
+
+def perlin(width, height, scale=10):
+ gx, gy = paddle.randn([2, width + 1, height + 1, 1, 1])
+ xs = paddle.linspace(0, 1, scale + 1)[:-1, None]
+ ys = paddle.linspace(0, 1, scale + 1)[None, :-1]
+ wx = 1 - interp(xs)
+ wy = 1 - interp(ys)
+ dots = 0
+ dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
+ dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
+ dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
+ dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] *
+ (1 - ys))
+ return dots.transpose([0, 2, 1, 3]).reshape([width * scale, height * scale])
+
+
+def perlin_ms(octaves, width, height, grayscale):
+ out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]
+ # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
+ for i in range(1 if grayscale else 3):
+ scale = 2**len(octaves)
+ oct_width = width
+ oct_height = height
+ for oct in octaves:
+ p = perlin(oct_width, oct_height, scale)
+ out_array[i] += p * oct
+ scale //= 2
+ oct_width *= 2
+ oct_height *= 2
+ return paddle.concat(out_array)
+
+
+def create_perlin_noise(octaves, width, height, grayscale, side_y, side_x):
+ out = perlin_ms(octaves, width, height, grayscale)
+ if grayscale:
+ out = PF.resize(size=(side_y, side_x), img=out.numpy())
+ out = np.uint8(out)
+ out = Image.fromarray(out).convert('RGB')
+ else:
+ out = out.reshape([-1, 3, out.shape[0] // 3, out.shape[1]])
+ out = out.squeeze().transpose([1, 2, 0]).numpy()
+ out = PF.resize(size=(side_y, side_x), img=out)
+ out = out.clip(0, 1) * 255
+ out = np.uint8(out)
+ out = Image.fromarray(out)
+
+ out = ImageOps.autocontrast(out)
+ return out
+
+
+def regen_perlin(perlin_mode, side_y, side_x, batch_size):
+ if perlin_mode == 'color':
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1,
+ False, side_y, side_x)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4,
+ False, side_y, side_x)
+ elif perlin_mode == 'gray':
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1,
+ True, side_y, side_x)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4,
+ True, side_y, side_x)
+ else:
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1,
+ False, side_y, side_x)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4,
+ True, side_y, side_x)
+
+ init = (PF.to_tensor(init).add(PF.to_tensor(init2)).divide(
+ paddle.to_tensor(2.0)).unsqueeze(0) * 2 - 1)
+ del init2
+ return init.expand([batch_size, -1, -1, -1])
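+
+
+# Illustrative usage sketch (not part of the original patch):
+#   init = regen_perlin('color', side_y=512, side_x=512, batch_size=2)
+#   # -> a [2, 3, 512, 512] tensor of blended multi-octave Perlin noise,
+#   #    scaled to roughly [-1, 1] for use as a diffusion init image.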
diff --git a/paddlenlp/transformers/guided_diffusion_utils/resize_right.py b/paddlenlp/transformers/guided_diffusion_utils/resize_right.py
new file mode 100755
index 000000000000..a65ed0b9116d
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/resize_right.py
@@ -0,0 +1,477 @@
+from fractions import Fraction
+from math import ceil
+
+from math import pi
+
+import paddle
+import paddle.nn as nn
+import numpy
+import numpy as np
+
+nnModuleWrapped = nn.Layer
+
+
+def set_framework_dependencies(x):
+ if type(x) is numpy.ndarray:
+ to_dtype = lambda a: a
+ fw = numpy
+ else:
+ to_dtype = lambda a: paddle.cast(a, x.dtype)
+ fw = paddle
+ # eps = fw.finfo(fw.float32).eps
+ eps = paddle.to_tensor(np.finfo(np.float32).eps)
+ return fw, to_dtype, eps
+
+
+def support_sz(sz):
+
+ def wrapper(f):
+ f.support_sz = sz
+ return f
+
+ return wrapper
+
+
+@support_sz(4)
+def cubic(x):
+ fw, to_dtype, eps = set_framework_dependencies(x)
+ absx = fw.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
+ (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
+ to_dtype((1. < absx) & (absx <= 2.)))
+
+
+@support_sz(4)
+def lanczos2(x):
+ fw, to_dtype, eps = set_framework_dependencies(x)
+ return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
+ ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
+
+
+@support_sz(6)
+def lanczos3(x):
+ fw, to_dtype, eps = set_framework_dependencies(x)
+ return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
+ ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
+
+
+@support_sz(2)
+def linear(x):
+ fw, to_dtype, eps = set_framework_dependencies(x)
+ return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) +
+ (1 - x) * to_dtype((0 <= x) & (x <= 1)))
+
+
+@support_sz(1)
+def box(x):
+ fw, to_dtype, eps = set_framework_dependencies(x)
+ return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
+
+
+def resize(input,
+ scale_factors=None,
+ out_shape=None,
+ interp_method=cubic,
+ support_sz=None,
+ antialiasing=True,
+ by_convs=False,
+ scale_tolerance=None,
+ max_numerator=10,
+ pad_mode='constant'):
+ # get properties of the input tensor
+ in_shape, n_dims = input.shape, input.ndim
+
+ # fw stands for framework that can be either numpy or paddle,
+ # determined by the input type
+ fw = numpy if type(input) is numpy.ndarray else paddle
+ eps = np.finfo(np.float32).eps if fw == numpy else paddle.to_tensor(
+ np.finfo(np.float32).eps)
+ device = input.place if fw is paddle else None
+
+    # set missing scale factors or output shape, one according to the other;
+    # raise an error if both are missing. this is also where all the default
+    # policies take place. the by_convs attribute is also handled carefully here.
+ scale_factors, out_shape, by_convs = set_scale_and_out_sz(
+ in_shape, out_shape, scale_factors, by_convs, scale_tolerance,
+ max_numerator, eps, fw)
+
+ # sort indices of dimensions according to scale of each dimension.
+ # since we are going dim by dim this is efficient
+ sorted_filtered_dims_and_scales = [
+ (dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim])
+ for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind])
+ if scale_factors[dim] != 1.
+ ]
+ # unless support size is specified by the user, it is an attribute
+ # of the interpolation method
+ if support_sz is None:
+ support_sz = interp_method.support_sz
+
+ # output begins identical to input and changes with each iteration
+ output = input
+
+ # iterate over dims
+ for (dim, scale_factor, dim_by_convs, in_sz,
+ out_sz) in sorted_filtered_dims_and_scales:
+ # STEP 1- PROJECTED GRID: The non-integer locations of the projection
+ # of output pixel locations to the input tensor
+ projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw,
+ dim_by_convs, device)
+
+ # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify
+ # the window size and the interpolation method (see inside function)
+ cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(
+ interp_method, support_sz, scale_factor, antialiasing)
+
+        # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
+        # that influence it. Also calculate the needed padding and update the
+        # grid accordingly
+ field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw,
+ eps, device)
+
+ # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view,
+ # the input should be padded to handle the boundaries, coordinates
+ # should be updated. actual padding only occurs when weights are
+        # applied (step 4). if using by_convs for this dim, then we need to
+ # calc right and left boundaries for each filter instead.
+ pad_sz, projected_grid, field_of_view = calc_pad_sz(
+ in_sz, out_sz, field_of_view, projected_grid, scale_factor,
+ dim_by_convs, fw, device)
+ # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in
+ # the field of view for each output pixel
+ weights = get_weights(cur_interp_method, projected_grid, field_of_view)
+
+ # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying
+ # its set of weights with the pixel values in its field of view.
+ # We now multiply the fields of view with their matching weights.
+ # We do this by tensor multiplication and broadcasting.
+ # if by_convs is true for this dim, then we do this action by
+ # convolutions. this is equivalent but faster.
+ if not dim_by_convs:
+ output = apply_weights(output, field_of_view, weights, dim, n_dims,
+ pad_sz, pad_mode, fw)
+ else:
+ output = apply_convs(output, scale_factor, in_sz, out_sz, weights,
+ dim, pad_sz, pad_mode, fw)
+ return output
+
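+# Illustrative usage sketch (not part of the original patch):
+#   img = paddle.rand([1, 3, 256, 256])
+#   small = resize(img, scale_factors=0.5)       # -> [1, 3, 128, 128]
+#   back = resize(small, out_shape=(256, 256))   # -> [1, 3, 256, 256]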
+
+def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
+    # we start by having the output coordinates, which are just integer
+    # locations. in the special case when using by_convs, we only need two
+    # cycles of grid points: the first and the last.
+ grid_sz = out_sz if not by_convs else scale_factor.numerator
+ out_coordinates = fw_arange(grid_sz, fw, device)
+
+    # This is projecting the output pixel locations in 1d to the input tensor,
+    # as non-integer locations.
+    # the following formula is derived in the paper
+    # "From Discrete to Continuous Convolutions" by Shocher et al.
+ return (out_coordinates / float(scale_factor) + (in_sz - 1) / 2 -
+ (out_sz - 1) / (2 * float(scale_factor)))
+
+
+def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
+ # for each output pixel, map which input pixels influence it, in 1d.
+ # we start by calculating the leftmost neighbor, using half of the window
+ # size (eps is for when boundary is exact int)
+ left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)
+
+ # then we simply take all the pixel centers in the field by counting
+ # window size pixels from the left boundary
+ ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device)
+ return left_boundaries[:, None] + ordinal_numbers
+
+
+def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor,
+ dim_by_convs, fw, device):
+ if not dim_by_convs:
+ # determine padding according to neighbor coords out of bound.
+ # this is a generalized notion of padding, when pad<0 it means crop
+ pad_sz = [
+ -field_of_view[0, 0].item(),
+ field_of_view[-1, -1].item() - in_sz + 1
+ ]
+
+ # since input image will be changed by padding, coordinates of both
+ # field_of_view and projected_grid need to be updated
+ field_of_view += pad_sz[0]
+ projected_grid += pad_sz[0]
+
+ else:
+ # only used for by_convs, to calc the boundaries of each filter the
+ # number of distinct convolutions is the numerator of the scale factor
+ num_convs, stride = scale_factor.numerator, scale_factor.denominator
+
+        # calculate left and right boundaries for each conv. left can also be
+        # negative and right can be bigger than in_sz. such cases imply padding
+        # if needed. however, if both are in-bounds, it means we need to crop,
+        # i.e. practically apply the conv only on part of the image.
+ left_pads = -field_of_view[:, 0]
+
+ # next calc is tricky, explanation by rows:
+ # 1) counting output pixels between the first position of each filter
+ # to the right boundary of the input
+ # 2) dividing it by number of filters to count how many 'jumps'
+ # each filter does
+ # 3) multiplying by the stride gives us the distance over the input
+ # coords done by all these jumps for each filter
+ # 4) to this distance we add the right boundary of the filter when
+ # placed in its leftmost position. so now we get the right boundary
+ # of that filter in input coord.
+ # 5) the padding size needed is obtained by subtracting the rightmost
+ # input coordinate. if the result is positive padding is needed. if
+ # negative then negative padding means shaving off pixel columns.
+ right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1) # (1)
+ // num_convs) # (2)
+ * stride # (3)
+ + field_of_view[:, -1] # (4)
+ - in_sz + 1) # (5)
+
+        # in the by_convs case pad_sz is a list of left-right pairs, one per
+        # filter
+
+ pad_sz = list(zip(left_pads, right_pads))
+
+ return pad_sz, projected_grid, field_of_view
+
+
+def get_weights(interp_method, projected_grid, field_of_view):
+    # the set of weights for each output pixel is the result of the chosen
+ # interpolation method applied to the distances between projected grid
+ # locations and the pixel-centers in the field of view (distances are
+ # directed, can be positive or negative)
+ weights = interp_method(projected_grid[:, None] - field_of_view)
+
+    # we now carefully normalize the weights to sum to 1 for each output pixel
+ sum_weights = weights.sum(1, keepdim=True)
+ sum_weights[sum_weights == 0] = 1
+ return weights / sum_weights
+
+
+def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode,
+ fw):
+ # for this operation we assume the resized dim is the first one.
+ # so we transpose and will transpose back after multiplying
+ tmp_input = fw_swapaxes(input, dim, 0, fw)
+
+ # apply padding
+ tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode)
+
+ # field_of_view is a tensor of order 2: for each output (1d location
+ # along cur dim)- a list of 1d neighbors locations.
+ # note that this whole operations is applied to each dim separately,
+ # this is why it is all in 1d.
+ # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:
+ # for each output pixel (this time indicated in all dims), these are the
+ # values of the neighbors in the 1d field of view. note that we only
+ # consider neighbors along the current dim, but such set exists for every
+ # multi-dim location, hence the final tensor order is image_dims+1.
+ paddle.device.cuda.empty_cache()
+ neighbors = tmp_input[field_of_view]
+
+ # weights is an order 2 tensor: for each output location along 1d- a list
+ # of weights matching the field of view. we augment it with ones, for
+ # broadcasting, so that when multiplies some tensor the weights affect
+ # only its first dim.
+ tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1)))
+
+ # now we simply multiply the weights with the neighbors, and then sum
+ # along the field of view, to get a single value per out pixel
+ tmp_output = (neighbors * tmp_weights).sum(1)
+ # we transpose back the resized dim to its original position
+ return fw_swapaxes(tmp_output, 0, dim, fw)
+
+
+def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz,
+ pad_mode, fw):
+    # for this operation we assume the resized dim is the last one.
+ # so we transpose and will transpose back after multiplying
+ input = fw_swapaxes(input, dim, -1, fw)
+
+ # the stride for all convs is the denominator of the scale factor
+ stride, num_convs = scale_factor.denominator, scale_factor.numerator
+
+ # prepare an empty tensor for the output
+ tmp_out_shape = list(input.shape)
+ tmp_out_shape[-1] = out_sz
+ tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device)
+
+ # iterate over the conv operations. we have as many as the numerator
+ # of the scale-factor. for each we need boundaries and a filter.
+ for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)):
+ # apply padding (we pad last dim, padding can be negative)
+ pad_dim = input.ndim - 1
+ tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim)
+
+        # apply convolution over last dim. store in the output tensor with
+        # positional strides so that when the loop is complete the conv results
+        # are interleaved
+ tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride)
+
+ return fw_swapaxes(tmp_output, -1, dim, fw)
+
+
+def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs,
+ scale_tolerance, max_numerator, eps, fw):
+ # eventually we must have both scale-factors and out-sizes for all in/out
+ # dims. however, we support many possible partial arguments
+ if scale_factors is None and out_shape is None:
+ raise ValueError("either scale_factors or out_shape should be "
+ "provided")
+ if out_shape is not None:
+        # if out_shape has fewer dims than in_shape, we resize the first dims
+        # for numpy and the last dims for paddle by default
+ out_shape = (list(out_shape) + list(in_shape[len(out_shape):])
+ if fw is numpy else list(in_shape[:-len(out_shape)]) +
+ list(out_shape))
+ if scale_factors is None:
+ # if no scale given, we calculate it as the out to in ratio
+            # (not recommended)
+ scale_factors = [
+ out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)
+ ]
+ if scale_factors is not None:
+ # by default, if a single number is given as scale, we assume resizing
+ # two dims (most common are images with 2 spatial dims)
+ scale_factors = (scale_factors if isinstance(
+ scale_factors, (list, tuple)) else [scale_factors, scale_factors])
+        # if there are fewer scale_factors than in_shape dims, we resize the
+        # first dims for numpy and the last dims for paddle by default
+ scale_factors = (list(scale_factors) + [1] *
+ (len(in_shape) - len(scale_factors)) if fw is numpy
+ else [1] * (len(in_shape) - len(scale_factors)) +
+ list(scale_factors))
+ if out_shape is None:
+ # when no out_shape given, it is calculated by multiplying the
+            # scale by the in_shape (not recommended)
+ out_shape = [
+ ceil(scale_factor * in_sz)
+ for scale_factor, in_sz in zip(scale_factors, in_shape)
+ ]
+ # next part intentionally after out_shape determined for stability
+ # we fix by_convs to be a list of truth values in case it is not
+ if not isinstance(by_convs, (list, tuple)):
+ by_convs = [by_convs] * len(out_shape)
+
+ # next loop fixes the scale for each dim to be either frac or float.
+ # this is determined by by_convs and by tolerance for scale accuracy.
+ for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
+        # first we fractionalize the scale factor
+ if dim_by_convs:
+ frac = Fraction(1 / sf).limit_denominator(max_numerator)
+ frac = Fraction(numerator=frac.denominator,
+ denominator=frac.numerator)
+
+ # if accuracy is within tolerance scale will be frac. if not, then
+ # it will be float and the by_convs attr will be set false for
+ # this dim
+ if scale_tolerance is None:
+ scale_tolerance = eps
+ if dim_by_convs and abs(frac - sf) < scale_tolerance:
+ scale_factors[ind] = frac
+ else:
+ scale_factors[ind] = float(sf)
+ by_convs[ind] = False
+
+ return scale_factors, out_shape, by_convs
+
+
+def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
+ antialiasing):
+ # antialiasing is "stretching" the field of view according to the scale
+ # factor (only for downscaling). this is low-pass filtering. this
+ # requires modifying both the interpolation (stretching the 1d
+ # function and multiplying by the scale-factor) and the window size.
+ scale_factor = float(scale_factor)
+ if scale_factor >= 1.0 or not antialiasing:
+ return interp_method, support_sz
+ cur_interp_method = (
+ lambda arg: scale_factor * interp_method(scale_factor * arg))
+ cur_support_sz = support_sz / scale_factor
+ return cur_interp_method, cur_support_sz
+
+
+def fw_ceil(x, fw):
+ if fw is numpy:
+ return fw.int_(fw.ceil(x))
+ else:
+ return paddle.cast(x.ceil(), dtype='int64')
+
+
+def fw_floor(x, fw):
+ if fw is numpy:
+ return fw.int_(fw.floor(x))
+ else:
+ return paddle.cast(x.floor(), dtype='int64')
+
+
+def fw_cat(x, fw):
+ if fw is numpy:
+ return fw.concatenate(x)
+ else:
+ return fw.concat(x)
+
+
+def fw_swapaxes(x, ax_1, ax_2, fw):
+ if fw is numpy:
+ return fw.swapaxes(x, ax_1, ax_2)
+ else:
+ if ax_1 == -1:
+ ax_1 = len(x.shape) - 1
+ if ax_2 == -1:
+ ax_2 = len(x.shape) - 1
+ perm0 = list(range(len(x.shape)))
+ temp = ax_1
+ perm0[temp] = ax_2
+ perm0[ax_2] = temp
+ return fw.transpose(x, perm0)
+
+
+def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
+ if pad_sz == (0, 0):
+ return x
+ if fw is numpy:
+ pad_vec = [(0, 0)] * x.ndim
+ pad_vec[dim] = pad_sz
+ return fw.pad(x, pad_width=pad_vec, mode=pad_mode)
+ else:
+ if x.ndim < 3:
+ x = x[None, None, ...]
+
+ pad_vec = [0] * ((x.ndim - 2) * 2)
+ pad_vec[0:2] = pad_sz
+ return fw_swapaxes(
+ fw.nn.functional.pad(fw_swapaxes(x, dim, -1, fw),
+ pad=pad_vec,
+ mode=pad_mode), dim, -1, fw)
+
+
+def fw_conv(input, filter, stride):
+ # we want to apply 1d conv to any nd array. the way to do it is to reshape
+    # the input to a 4D tensor. first two dims are singletons, 3rd dim stores
+ # all the spatial dims that we are not convolving along now. then we can
+ # apply conv2d with a 1xK filter. This convolves the same way all the other
+ # dims stored in the 3d dim. like depthwise conv over these.
+ # TODO: numpy support
+    reshaped_input = input.reshape([1, 1, -1, input.shape[-1]])
+    reshaped_output = paddle.nn.functional.conv2d(reshaped_input,
+                                                  filter.reshape([1, 1, 1, -1]),
+                                                  stride=(1, stride))
+    return reshaped_output.reshape([*input.shape[:-1], -1])
+
+
+def fw_arange(upper_bound, fw, device):
+ if fw is numpy:
+ return fw.arange(upper_bound)
+ else:
+ return fw.arange(upper_bound)
+
+
+def fw_empty(shape, fw, device):
+ if fw is numpy:
+ return fw.empty(shape)
+ else:
+ return fw.empty(shape=shape)
diff --git a/paddlenlp/transformers/guided_diffusion_utils/sec_diff.py b/paddlenlp/transformers/guided_diffusion_utils/sec_diff.py
new file mode 100644
index 000000000000..84875cc1779d
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/sec_diff.py
@@ -0,0 +1,139 @@
+'''
+This code is rewritten in Paddle based on
+https://github.com/jina-ai/discoart/blob/main/discoart/nn/sec_diff.py
+'''
+import math
+from dataclasses import dataclass
+from functools import partial
+
+import paddle
+import paddle.nn as nn
+
+
+@dataclass
+class DiffusionOutput:
+ v: paddle.Tensor
+ pred: paddle.Tensor
+ eps: paddle.Tensor
+
+
+class SkipBlock(nn.Layer):
+
+ def __init__(self, main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ return paddle.concat([self.main(input), self.skip(input)], axis=1)
+
+
+def append_dims(x, n):
+ return x[(Ellipsis, *(None, ) * (n - x.ndim))]
+
+
+def expand_to_planes(x, shape):
+ return paddle.tile(append_dims(x, len(shape)), [1, 1, *shape[2:]])
+
+
+def alpha_sigma_to_t(alpha, sigma):
+ return paddle.atan2(sigma, alpha) * 2 / math.pi
+
+
+def t_to_alpha_sigma(t):
+ return paddle.cos(t * math.pi / 2), paddle.sin(t * math.pi / 2)
+
+
+class SecondaryDiffusionImageNet2(nn.Layer):
+
+ def __init__(self):
+ super().__init__()
+ c = 64 # The base channel count
+ cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
+
+ self.timestep_embed = FourierFeatures(1, 16)
+ self.down = nn.AvgPool2D(2)
+ self.up = nn.Upsample(scale_factor=2,
+ mode='bilinear',
+ align_corners=False)
+
+ self.net = nn.Sequential(
+ ConvBlock(3 + 16, cs[0]),
+ ConvBlock(cs[0], cs[0]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[0], cs[1]),
+ ConvBlock(cs[1], cs[1]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[1], cs[2]),
+ ConvBlock(cs[2], cs[2]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[2], cs[3]),
+ ConvBlock(cs[3], cs[3]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[3], cs[4]),
+ ConvBlock(cs[4], cs[4]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[4], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[4]),
+ self.up,
+ ]),
+ ConvBlock(cs[4] * 2, cs[4]),
+ ConvBlock(cs[4], cs[3]),
+ self.up,
+ ]),
+ ConvBlock(cs[3] * 2, cs[3]),
+ ConvBlock(cs[3], cs[2]),
+ self.up,
+ ]),
+ ConvBlock(cs[2] * 2, cs[2]),
+ ConvBlock(cs[2], cs[1]),
+ self.up,
+ ]),
+ ConvBlock(cs[1] * 2, cs[1]),
+ ConvBlock(cs[1], cs[0]),
+ self.up,
+ ]),
+ ConvBlock(cs[0] * 2, cs[0]),
+ nn.Conv2D(cs[0], 3, 3, padding=1),
+ )
+
+ def forward(self, input, t):
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]),
+ input.shape)
+ v = self.net(paddle.concat([input, timestep_embed], axis=1))
+ alphas, sigmas = map(partial(append_dims, n=v.ndim),
+ t_to_alpha_sigma(t))
+ pred = input * alphas - v * sigmas
+ eps = input * sigmas + v * alphas
+ return DiffusionOutput(v, pred, eps)
+
+
+class FourierFeatures(nn.Layer):
+
+ def __init__(self, in_features, out_features, std=1.0):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.weight = paddle.create_parameter(
+ [out_features // 2, in_features],
+ dtype='float32',
+ default_initializer=nn.initializer.Normal(mean=0.0, std=std))
+
+ def forward(self, input):
+ f = 2 * math.pi * input @ self.weight.T
+ return paddle.concat([f.cos(), f.sin()], axis=-1)
+
+
+class ConvBlock(nn.Sequential):
+
+ def __init__(self, c_in, c_out):
+ super().__init__(
+ nn.Conv2D(c_in, c_out, 3, padding=1),
+ nn.ReLU(),
+ )
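+
+
+# Illustrative usage sketch (not part of the original patch):
+#   model = SecondaryDiffusionImageNet2()
+#   x = paddle.rand([2, 3, 64, 64])
+#   t = paddle.rand([2])             # diffusion "time" in [0, 1]
+#   out = model(x, t)                # DiffusionOutput(v, pred, eps), each of
+#                                    # shape [2, 3, 64, 64]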
diff --git a/paddlenlp/transformers/guided_diffusion_utils/transforms.py b/paddlenlp/transformers/guided_diffusion_utils/transforms.py
new file mode 100755
index 000000000000..022be4688a92
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/transforms.py
@@ -0,0 +1,806 @@
+'''
+This code is rewritten in Paddle based on
+https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
+'''
+import math
+import numbers
+import warnings
+from enum import Enum
+from typing import List, Optional, Sequence, Tuple
+
+import paddle
+import paddle.nn as nn
+from paddle.nn.functional import grid_sample
+
+
+class Normalize(nn.Layer):
+
+ def __init__(self, mean, std):
+ super(Normalize, self).__init__()
+ self.mean = paddle.to_tensor(mean)
+ self.std = paddle.to_tensor(std)
+
+ def forward(self, tensor):
+ dtype = tensor.dtype
+ mean = paddle.cast(self.mean, dtype=dtype).reshape([1, -1, 1, 1])
+ std = paddle.cast(self.std, dtype=dtype).reshape([1, -1, 1, 1])
+ result = tensor.subtract(mean).divide(std)
+ return result
+
+
+class InterpolationMode(Enum):
+ """Interpolation modes
+ Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
+ """
+
+ NEAREST = "nearest"
+ BILINEAR = "bilinear"
+ BICUBIC = "bicubic"
+ # For PIL compatibility
+ BOX = "box"
+ HAMMING = "hamming"
+ LANCZOS = "lanczos"
+
+
+class Grayscale(nn.Layer):
+
+ def __init__(self, num_output_channels):
+ super(Grayscale, self).__init__()
+ self.num_output_channels = num_output_channels
+
+ def forward(self, x):
+ output = (0.2989 * x[:, 0:1, :, :] + 0.587 * x[:, 1:2, :, :] +
+ 0.114 * x[:, 2:3, :, :])
+ if self.num_output_channels == 3:
+ return output.expand(x.shape)
+
+ return output
+
+
+class Lambda(nn.Layer):
+
+ def __init__(self, func):
+ super(Lambda, self).__init__()
+ self.transform = func
+
+ def forward(self, x):
+ return self.transform(x)
+
+
+class RandomGrayscale(nn.Layer):
+
+ def __init__(self, p):
+ super(RandomGrayscale, self).__init__()
+ self.prob = p
+ self.transform = Grayscale(3)
+
+ def forward(self, x):
+ if paddle.rand([1]) < self.prob:
+ return self.transform(x)
+ else:
+ return x
+
+
+class RandomHorizontalFlip(nn.Layer):
+
+ def __init__(self, prob):
+ super(RandomHorizontalFlip, self).__init__()
+ self.prob = prob
+
+ def forward(self, x):
+ if paddle.rand([1]) < self.prob:
+ return x[:, :, :, ::-1]
+ else:
+ return x
+
+
+def _blend(img1, img2, ratio: float):
+ ratio = float(ratio)
+ bound = 1.0
+ return (ratio * img1 + (1.0 - ratio) * img2).clip(0, bound)
+
+
+def trunc_div(a, b):
+ ipt = paddle.divide(a, b)
+ sign_ipt = paddle.sign(ipt)
+ abs_ipt = paddle.abs(ipt)
+ abs_ipt = paddle.floor(abs_ipt)
+ out = paddle.multiply(sign_ipt, abs_ipt)
+ return out
+
+
+def fmod(a, b):
+ return a - trunc_div(a, b) * b
+
+
+def _rgb2hsv(img):
+ r, g, b = img.unbind(axis=-3)
+
+ # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
+ # src/libImaging/Convert.c#L330
+ maxc = paddle.max(img, axis=-3)
+ minc = paddle.min(img, axis=-3)
+
+ # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
+ # from happening in the results, because
+ # + S channel has division by `maxc`, which is zero only if `maxc = minc`
+ # + H channel has division by `(maxc - minc)`.
+ #
+    # Instead of overwriting NaN afterwards, we just prevent it from occurring so
+ # we don't need to deal with it in case we save the NaN in a buffer in
+ # backprop, if it is ever supported, but it doesn't hurt to do so.
+ eqc = maxc == minc
+
+ cr = maxc - minc
+    # Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` holds is fine.
+ ones = paddle.ones_like(maxc)
+ s = cr / paddle.where(eqc, ones, maxc)
+ # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
+ # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
+ # would not matter what values `rc`, `gc`, and `bc` have here, and thus
+    # replacing the denominator with 1 when `eqc` holds is fine.
+ cr_divisor = paddle.where(eqc, ones, cr)
+ rc = (maxc - r) / cr_divisor
+ gc = (maxc - g) / cr_divisor
+ bc = (maxc - b) / cr_divisor
+
+ hr = (maxc == r).cast('float32') * (bc - gc)
+ hg = ((maxc == g) & (maxc != r)).cast('float32') * (2.0 + rc - bc)
+ hb = ((maxc != g) & (maxc != r)).cast('float32') * (4.0 + gc - rc)
+ h = hr + hg + hb
+ h = fmod((h / 6.0 + 1.0), paddle.to_tensor(1.0))
+ return paddle.stack((h, s, maxc), axis=-3)
+
+
+def _hsv2rgb(img):
+ h, s, v = img.unbind(axis=-3)
+ i = paddle.floor(h * 6.0)
+ f = (h * 6.0) - i
+ i = i.cast(dtype='int32')
+
+ p = paddle.clip((v * (1.0 - s)), 0.0, 1.0)
+ q = paddle.clip((v * (1.0 - s * f)), 0.0, 1.0)
+ t = paddle.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
+ i = i % 6
+
+ mask = i.unsqueeze(axis=-3) == paddle.arange(6).reshape([-1, 1, 1])
+
+ a1 = paddle.stack((v, q, p, p, t, v), axis=-3)
+ a2 = paddle.stack((t, v, v, q, p, p), axis=-3)
+ a3 = paddle.stack((p, p, t, v, v, q), axis=-3)
+ a4 = paddle.stack((a1, a2, a3), axis=-4)
+
+ return paddle.einsum("...ijk, ...xijk -> ...xjk",
+ mask.cast(dtype=img.dtype), a4)
+
+
+def adjust_brightness(img, brightness_factor: float):
+ if brightness_factor < 0:
+ raise ValueError(
+ f"brightness_factor ({brightness_factor}) is not non-negative.")
+
+ return _blend(img, paddle.zeros_like(img), brightness_factor)
+
+
+def adjust_contrast(img, contrast_factor: float):
+ if contrast_factor < 0:
+ raise ValueError(
+ f"contrast_factor ({contrast_factor}) is not non-negative.")
+
+ c = img.shape[1]
+
+ if c == 3:
+ output = (0.2989 * img[:, 0:1, :, :] + 0.587 * img[:, 1:2, :, :] +
+ 0.114 * img[:, 2:3, :, :])
+ mean = paddle.mean(output, axis=(-3, -2, -1), keepdim=True)
+
+ else:
+ mean = paddle.mean(img, axis=(-3, -2, -1), keepdim=True)
+
+ return _blend(img, mean, contrast_factor)
+
+
+def adjust_hue(img, hue_factor: float):
+ if not (-0.5 <= hue_factor <= 0.5):
+ raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
+
+ img = _rgb2hsv(img)
+ h, s, v = img.unbind(axis=-3)
+ h = fmod(h + hue_factor, paddle.to_tensor(1.0))
+ img = paddle.stack((h, s, v), axis=-3)
+ img_hue_adj = _hsv2rgb(img)
+ return img_hue_adj
+
+
+def adjust_saturation(img, saturation_factor: float):
+ if saturation_factor < 0:
+ raise ValueError(
+ f"saturation_factor ({saturation_factor}) is not non-negative.")
+
+ output = (0.2989 * img[:, 0:1, :, :] + 0.587 * img[:, 1:2, :, :] +
+ 0.114 * img[:, 2:3, :, :])
+
+ return _blend(img, output, saturation_factor)
+
+
+class ColorJitter(nn.Layer):
+
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+ super(ColorJitter, self).__init__()
+ self.brightness = self._check_input(brightness, "brightness")
+ self.contrast = self._check_input(contrast, "contrast")
+ self.saturation = self._check_input(saturation, "saturation")
+ self.hue = self._check_input(hue,
+ "hue",
+ center=0,
+ bound=(-0.5, 0.5),
+ clip_first_on_zero=False)
+
+ def _check_input(self,
+ value,
+ name,
+ center=1,
+ bound=(0, float("inf")),
+ clip_first_on_zero=True):
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError(
+ f"If {name} is a single number, it must be non negative.")
+ value = [center - float(value), center + float(value)]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0.0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError(f"{name} values should be between {bound}")
+ else:
+ raise TypeError(
+ f"{name} should be a single number or a list/tuple with length 2."
+ )
+
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ value = None
+ return value
+
+ @staticmethod
+ def get_params(
+ brightness: Optional[List[float]],
+ contrast: Optional[List[float]],
+ saturation: Optional[List[float]],
+ hue: Optional[List[float]],
+ ) -> Tuple[paddle.Tensor, Optional[float], Optional[float], Optional[float],
+ Optional[float]]:
+ """Get the parameters for the randomized transform to be applied on image.
+
+ Args:
+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
+ uniformly. Pass None to turn off the transformation.
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
+ uniformly. Pass None to turn off the transformation.
+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
+ uniformly. Pass None to turn off the transformation.
+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
+ Pass None to turn off the transformation.
+
+ Returns:
+ tuple: The parameters used to apply the randomized transform
+ along with their random order.
+ """
+ fn_idx = paddle.randperm(4)
+
+ b = None if brightness is None else paddle.empty([1]).uniform_(
+ brightness[0], brightness[1])
+ c = None if contrast is None else paddle.empty([1]).uniform_(
+ contrast[0], contrast[1])
+ s = None if saturation is None else paddle.empty([1]).uniform_(
+ saturation[0], saturation[1])
+ h = None if hue is None else paddle.empty([1]).uniform_(hue[0], hue[1])
+
+ return fn_idx, b, c, s, h
+
+ def forward(self, img):
+ """
+ Args:
+ img (PIL Image or Tensor): Input image.
+
+ Returns:
+ PIL Image or Tensor: Color jittered image.
+ """
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue)
+
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness_factor is not None:
+ img = adjust_brightness(img, brightness_factor)
+ elif fn_id == 1 and contrast_factor is not None:
+ img = adjust_contrast(img, contrast_factor)
+ elif fn_id == 2 and saturation_factor is not None:
+ img = adjust_saturation(img, saturation_factor)
+ elif fn_id == 3 and hue_factor is not None:
+ img = adjust_hue(img, hue_factor)
+
+ return img
+
+ def __repr__(self) -> str:
+ s = (f"{self.__class__.__name__}("
+ f"brightness={self.brightness}"
+ f", contrast={self.contrast}"
+ f", saturation={self.saturation}"
+ f", hue={self.hue})")
+ return s
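+
+# Illustrative usage sketch (not part of the original patch): unlike the
+# PIL-based torchvision transform, this layer operates on NCHW float tensors
+# in [0, 1].
+#   jitter = ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
+#   out = jitter(paddle.rand([4, 3, 32, 32]))   # same shape, values in [0, 1]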
+
+
+def _apply_grid_transform(img, grid, mode: str, fill: Optional[List[float]]):
+
+ if img.shape[0] > 1:
+ # Apply same grid to a batch of images
+ grid = grid.expand(
+ [img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]])
+
+ # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
+ if fill is not None:
+ dummy = paddle.ones((img.shape[0], 1, img.shape[2], img.shape[3]),
+ dtype=img.dtype)
+ img = paddle.concat((img, dummy), axis=1)
+
+ img = grid_sample(img,
+ grid,
+ mode=mode,
+ padding_mode="zeros",
+ align_corners=False)
+
+ # Fill with required color
+ if fill is not None:
+ mask = img[:, -1:, :, :] # N * 1 * H * W
+ img = img[:, :-1, :, :] # N * C * H * W
+ mask = mask.expand_as(img)
+ len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
+ fill_img = paddle.to_tensor(fill, dtype=img.dtype).reshape(
+ [1, len_fill, 1, 1]).expand_as(img)
+ if mode == "nearest":
+ mask = mask < 0.5
+ img[mask] = fill_img[mask]
+ else: # 'bilinear'
+ img = img * mask + (1.0 - mask) * fill_img
+ return img
+
+
+def _gen_affine_grid(
+ theta,
+ w: int,
+ h: int,
+ ow: int,
+ oh: int,
+):
+ # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
+ # AffineGridGenerator.cpp#L18
+ # Difference with AffineGridGenerator is that:
+ # 1) we normalize grid values after applying theta
+ # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
+
+ d = 0.5
+ base_grid = paddle.empty([1, oh, ow, 3], dtype=theta.dtype)
+ x_grid = paddle.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, num=ow)
+ base_grid[..., 0] = (x_grid)
+ y_grid = paddle.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1,
+ num=oh).unsqueeze_(-1)
+ base_grid[..., 1] = (y_grid)
+ base_grid[..., 2] = 1.0
+ rescaled_theta = theta.transpose([0, 2, 1]) / paddle.to_tensor(
+ [0.5 * w, 0.5 * h], dtype=theta.dtype)
+ output_grid = base_grid.reshape([1, oh * ow, 3]).bmm(rescaled_theta)
+ return output_grid.reshape([1, oh, ow, 2])
+
+
+def affine_impl(img,
+ matrix: List[float],
+ interpolation: str = "nearest",
+ fill: Optional[List[float]] = None):
+ theta = paddle.to_tensor(matrix, dtype=img.dtype).reshape([1, 2, 3])
+ shape = img.shape
+ # grid will be generated on the same device as theta and img
+ grid = _gen_affine_grid(theta,
+ w=shape[-1],
+ h=shape[-2],
+ ow=shape[-1],
+ oh=shape[-2])
+ return _apply_grid_transform(img, grid, interpolation, fill=fill)
+
+
+def _get_inverse_affine_matrix(center: List[float],
+ angle: float,
+ translate: List[float],
+ scale: float,
+ shear: List[float],
+ inverted: bool = True) -> List[float]:
+ # Helper method to compute inverse matrix for affine transformation
+
+ # Pillow requires inverse affine transformation matrix:
+ # Affine matrix is : M = T * C * RotateScaleShear * C^-1
+ #
+ # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
+ # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
+ # RotateScaleShear is rotation with scale and shear matrix
+ #
+ # RotateScaleShear(a, s, (sx, sy)) =
+ # = R(a) * S(s) * SHy(sy) * SHx(sx)
+ # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
+ # [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
+ # [ 0 , 0 , 1 ]
+ # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
+ # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
+ # [0, 1 ] [-tan(s), 1]
+ #
+ # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
+
+ rot = math.radians(angle)
+ sx = math.radians(shear[0])
+ sy = math.radians(shear[1])
+
+ cx, cy = center
+ tx, ty = translate
+
+ # RSS without scaling
+ a = math.cos(rot - sy) / math.cos(sy)
+ b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
+ c = math.sin(rot - sy) / math.cos(sy)
+ d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
+
+ if inverted:
+ # Inverted rotation matrix with scale and shear
+ # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+ matrix = [d, -b, 0.0, -c, a, 0.0]
+ matrix = [x / scale for x in matrix]
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+ matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
+ matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+ matrix[2] += cx
+ matrix[5] += cy
+ else:
+ matrix = [a, b, 0.0, c, d, 0.0]
+ matrix = [x * scale for x in matrix]
+ # Apply inverse of center translation: RSS * C^-1
+ matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
+ matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
+ # Apply translation and center : T * C * RSS * C^-1
+ matrix[2] += cx + tx
+ matrix[5] += cy + ty
+
+ return matrix
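+
+# Illustrative sanity checks (example values are arbitrary): with no rotation,
+# translation, scaling or shear the helper returns the identity map in Pillow's
+# row-major 2x3 form, and a pure 90 degree rotation about the origin returns the
+# inverse rotation (up to floating-point rounding in the zero entries):
+#
+#   _get_inverse_affine_matrix([0.0, 0.0], 0.0, [0.0, 0.0], 1.0, [0.0, 0.0])
+#   # -> [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
+#   _get_inverse_affine_matrix([0.0, 0.0], 90.0, [0.0, 0.0], 1.0, [0.0, 0.0])
+#   # -> [0.0, 1.0, 0.0, -1.0, 0.0, 0.0]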
+
+
+def affine(
+ img,
+ angle: float,
+ translate: List[int],
+ scale: float,
+ shear: List[float],
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
+ fill: Optional[List[float]] = None,
+ resample: Optional[int] = None,
+ fillcolor: Optional[List[float]] = None,
+ center: Optional[List[int]] = None,
+):
+ """Apply affine transformation on the image keeping image center invariant.
+ If the image is a paddle Tensor, it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+ Args:
+ img (PIL Image or Tensor): image to transform.
+ angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
+ translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
+ scale (float): overall scale
+ shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
+ If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
+ the second value corresponds to a shear parallel to the y axis.
+ interpolation (InterpolationMode): Desired interpolation enum defined by
+ :class:`InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+ For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
+ but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+ image. If given a number, the value is used for all bands.
+
+ .. note::
+ A single int/float value is applied to every band; a sequence of
+ length 1 (``[value, ]``) is treated the same way.
+ fillcolor (sequence or number, optional):
+ .. warning::
+ This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead.
+ resample (int, optional):
+ .. warning::
+ This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
+ instead.
+ center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
+ Default is the center of the image.
+
+ Returns:
+ PIL Image or Tensor: Transformed image.
+ """
+
+ # Backward compatibility with integer value
+ if isinstance(interpolation, int):
+ warnings.warn(
+ "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
+ "Please use InterpolationMode enum.")
+ interpolation = _interpolation_modes_from_int(interpolation)
+
+ if fillcolor is not None:
+ warnings.warn(
+ "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. "
+ "Please use 'fill' instead.")
+ fill = fillcolor
+
+ if not isinstance(angle, (int, float)):
+ raise TypeError("Argument angle should be int or float")
+
+ if not isinstance(translate, (list, tuple)):
+ raise TypeError("Argument translate should be a sequence")
+
+ if len(translate) != 2:
+ raise ValueError("Argument translate should be a sequence of length 2")
+
+ if scale <= 0.0:
+ raise ValueError("Argument scale should be positive")
+
+ if not isinstance(shear, (numbers.Number, (list, tuple))):
+ raise TypeError(
+ "Shear should be either a single value or a sequence of two values")
+
+ if not isinstance(interpolation, InterpolationMode):
+ raise TypeError("Argument interpolation should be a InterpolationMode")
+
+ if isinstance(angle, int):
+ angle = float(angle)
+
+ if isinstance(translate, tuple):
+ translate = list(translate)
+
+ if isinstance(shear, numbers.Number):
+ shear = [shear, 0.0]
+
+ if isinstance(shear, tuple):
+ shear = list(shear)
+
+ if len(shear) == 1:
+ shear = [shear[0], shear[0]]
+
+ if len(shear) != 2:
+ raise ValueError(
+ f"Shear should be a sequence containing two values. Got {shear}")
+
+ if center is not None and not isinstance(center, (list, tuple)):
+ raise TypeError("Argument center should be a sequence")
+ center_f = [0.0, 0.0]
+ if center is not None:
+ height, width = img.shape[-2], img.shape[-1]
+ # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
+ center_f = [
+ 1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])
+ ]
+
+ translate_f = [1.0 * t for t in translate]
+ matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale,
+ shear)
+ return affine_impl(img,
+ matrix=matrix,
+ interpolation=interpolation.value,
+ fill=fill)
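+
+# Usage sketch (example values are arbitrary): warp a dummy NCHW tensor with a
+# fixed rotation and translation; the fill list length matches the channel count.
+#
+#   img = paddle.rand([1, 3, 224, 224])
+#   out = affine(img, angle=30.0, translate=[10, 0], scale=1.0,
+#                shear=[0.0, 0.0],
+#                interpolation=InterpolationMode.BILINEAR,
+#                fill=[0.5, 0.5, 0.5])
+#   # out.shape == [1, 3, 224, 224]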
+
+
+def _interpolation_modes_from_int(i: int) -> InterpolationMode:
+ inverse_modes_mapping = {
+ 0: InterpolationMode.NEAREST,
+ 2: InterpolationMode.BILINEAR,
+ 3: InterpolationMode.BICUBIC,
+ 4: InterpolationMode.BOX,
+ 5: InterpolationMode.HAMMING,
+ 1: InterpolationMode.LANCZOS,
+ }
+ return inverse_modes_mapping[i]
+
+
+def _check_sequence_input(x, name, req_sizes):
+ msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join(
+ [str(s) for s in req_sizes])
+ if not isinstance(x, Sequence):
+ raise TypeError(f"{name} should be a sequence of length {msg}.")
+ if len(x) not in req_sizes:
+ raise ValueError(f"{name} should be sequence of length {msg}.")
+
+
+def _setup_angle(x, name, req_sizes=(2, )):
+ if isinstance(x, numbers.Number):
+ if x < 0:
+ raise ValueError(
+ f"If {name} is a single number, it must be positive.")
+ x = [-x, x]
+ else:
+ _check_sequence_input(x, name, req_sizes)
+
+ return [float(d) for d in x]
+
+
+class RandomAffine(nn.Layer):
+ """Random affine transformation of the image keeping center invariant.
+ If the image is a paddle Tensor, it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+ Args:
+ degrees (sequence or number): Range of degrees to select from.
+ If degrees is a number instead of sequence like (min, max), the range of degrees
+ will be (-degrees, +degrees). Set to 0 to deactivate rotations.
+ translate (tuple, optional): tuple of maximum absolute fraction for horizontal
+ and vertical translations. For example translate=(a, b), then horizontal shift
+ is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
+ randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
+ scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
+ randomly sampled from the range a <= scale <= b. Will keep original scale by default.
+ shear (sequence or number, optional): Range of degrees to select from.
+ If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
+ will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
+ range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
+ a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
+ Will not apply shear by default.
+ interpolation (InterpolationMode): Desired interpolation enum defined by
+ :class:`InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+ For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
+ but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
+ fill (sequence or number): Pixel fill value for the area outside the transformed
+ image. Default is ``0``. If given a number, the value is used for all bands.
+ fillcolor (sequence or number, optional):
+ .. warning::
+ This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead.
+ resample (int, optional):
+ .. warning::
+ This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
+ instead.
+ center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
+ Default is the center of the image.
+
+ .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+ """
+
+ def __init__(
+ self,
+ degrees,
+ translate=None,
+ scale=None,
+ shear=None,
+ interpolation=InterpolationMode.NEAREST,
+ fill=0,
+ fillcolor=None,
+ resample=None,
+ center=None,
+ ):
+ super(RandomAffine, self).__init__()
+ if resample is not None:
+ warnings.warn(
+ "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
+ "Please use 'interpolation' instead.")
+ interpolation = _interpolation_modes_from_int(resample)
+
+ # Backward compatibility with integer value
+ if isinstance(interpolation, int):
+ warnings.warn(
+ "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
+ "Please use InterpolationMode enum.")
+ interpolation = _interpolation_modes_from_int(interpolation)
+
+ if fillcolor is not None:
+ warnings.warn(
+ "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. "
+ "Please use 'fill' instead.")
+ fill = fillcolor
+
+ self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
+
+ if translate is not None:
+ _check_sequence_input(translate, "translate", req_sizes=(2, ))
+ for t in translate:
+ if not (0.0 <= t <= 1.0):
+ raise ValueError(
+ "translation values should be between 0 and 1")
+ self.translate = translate
+
+ if scale is not None:
+ _check_sequence_input(scale, "scale", req_sizes=(2, ))
+ for s in scale:
+ if s <= 0:
+ raise ValueError("scale values should be positive")
+ self.scale = scale
+
+ if shear is not None:
+ self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
+ else:
+ self.shear = shear
+
+ self.resample = self.interpolation = interpolation
+
+ if fill is None:
+ fill = 0
+ elif not isinstance(fill, (Sequence, numbers.Number)):
+ raise TypeError("Fill should be either a sequence or a number.")
+
+ self.fillcolor = self.fill = fill
+
+ if center is not None:
+ _check_sequence_input(center, "center", req_sizes=(2, ))
+
+ self.center = center
+
+ @staticmethod
+ def get_params(
+ degrees: List[float],
+ translate: Optional[List[float]],
+ scale_ranges: Optional[List[float]],
+ shears: Optional[List[float]],
+ img_size: List[int],
+ ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
+ """Get parameters for affine transformation
+
+ Returns:
+ params to be passed to the affine transformation
+ """
+ angle = float(
+ paddle.empty([1]).uniform_(float(degrees[0]), float(degrees[1])))
+ if translate is not None:
+ max_dx = float(translate[0] * img_size[0])
+ max_dy = float(translate[1] * img_size[1])
+ tx = int(float(paddle.empty([1]).uniform_(-max_dx, max_dx)))
+ ty = int(float(paddle.empty([1]).uniform_(-max_dy, max_dy)))
+ translations = (tx, ty)
+ else:
+ translations = (0, 0)
+
+ if scale_ranges is not None:
+ scale = float(
+ paddle.empty([1]).uniform_(scale_ranges[0], scale_ranges[1]))
+ else:
+ scale = 1.0
+
+ shear_x = shear_y = 0.0
+ if shears is not None:
+ shear_x = float(paddle.empty([1]).uniform_(shears[0], shears[1]))
+ if len(shears) == 4:
+ shear_y = float(
+ paddle.empty([1]).uniform_(shears[2], shears[3]))
+
+ shear = (shear_x, shear_y)
+
+ return angle, translations, scale, shear
+
+ def forward(self, img):
+ fill = self.fill
+ channels, height, width = img.shape[1], img.shape[2], img.shape[3]
+ if isinstance(fill, (int, float)):
+ fill = [float(fill)] * channels
+ else:
+ fill = [float(f) for f in fill]
+
+ img_size = [width, height] # flip for keeping BC on get_params call
+
+ ret = self.get_params(self.degrees, self.translate, self.scale,
+ self.shear, img_size)
+
+ return affine(img,
+ *ret,
+ interpolation=self.interpolation,
+ fill=fill,
+ center=self.center)
+
+ def __repr__(self) -> str:
+ s = f"{self.__class__.__name__}(degrees={self.degrees}"
+ s += f", translate={self.translate}" if self.translate is not None else ""
+ s += f", scale={self.scale}" if self.scale is not None else ""
+ s += f", shear={self.shear}" if self.shear is not None else ""
+ s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
+ s += f", fill={self.fill}" if self.fill != 0 else ""
+ s += f", center={self.center}" if self.center is not None else ""
+ s += ")"
+
+ return s
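+
+# Usage sketch (example values are arbitrary): the layer samples a fresh random
+# affine transform on every forward call and expects a batched NCHW tensor, as
+# in forward() above.
+#
+#   transform = RandomAffine(degrees=10, translate=(0.1, 0.1),
+#                            scale=(0.9, 1.1), shear=5,
+#                            interpolation=InterpolationMode.BILINEAR)
+#   batch = paddle.rand([4, 3, 224, 224])
+#   augmented = transform(batch)  # same shape, randomly warped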
diff --git a/paddlenlp/transformers/guided_diffusion_utils/unet.py b/paddlenlp/transformers/guided_diffusion_utils/unet.py
new file mode 100755
index 000000000000..67f8697830a5
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/unet.py
@@ -0,0 +1,723 @@
+'''
+This code is rewritten in Paddle, based on
+https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py
+'''
+import math
+from abc import abstractmethod
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class GroupNorm32(nn.GroupNorm):
+
+ def forward(self, x):
+ return super().forward(x)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1D(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2D(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3D(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1D(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2D(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3D(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().scale_(rate).add_(src * (1 - rate))
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().scale_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(axis=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = paddle.exp(-math.log(max_period) *
+ paddle.arange(start=0, end=half, dtype=paddle.float32) /
+ half)
+ args = paddle.cast(timesteps[:, None], 'float32') * freqs[None]
+ embedding = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1)
+ if dim % 2:
+ embedding = paddle.concat(
+ [embedding, paddle.zeros_like(embedding[:, :1])], axis=-1)
+ return embedding
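+
+# Illustrative shape check (example values are arbitrary): a batch of N
+# (possibly fractional) timesteps maps to an [N, dim] embedding whose first
+# half holds cosines and whose second half holds sines.
+#
+#   emb = timestep_embedding(paddle.to_tensor([0.0, 10.0, 999.0]), dim=128)
+#   # emb.shape == [3, 128]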
+
+
+class AttentionPool2d(nn.Layer):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = self.create_parameter(
+ shape=[embed_dim, spacial_dim**2 + 1],
+ default_initializer=nn.initializer.Assign(
+ paddle.randn([embed_dim, spacial_dim**2 + 1]) / embed_dim**0.5))
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c = x.shape[:2]
+ x = paddle.reshape(x, [b, c, -1])
+ x = paddle.concat([x.mean(axis=-1, keepdim=True), x],
+ axis=-1) # NC(HW+1)
+ x = x + paddle.cast(self.positional_embedding[None, :, :],
+ x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Layer):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Layer):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode="nearest")
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Layer):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=1)
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.Silu(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.Silu(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels
+ if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.Silu(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims,
+ self.out_channels,
+ self.out_channels,
+ 3,
+ padding=1)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims,
+ channels,
+ self.out_channels,
+ 3,
+ padding=1)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb)
+ emb_out = paddle.cast(emb_out, h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = paddle.chunk(emb_out, 2, axis=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
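+
+# Illustrative shape check (hyperparameters are arbitrary examples): a ResBlock
+# maps [N, channels, H, W] features plus an [N, emb_channels] timestep embedding
+# to [N, out_channels, H, W].
+#
+#   block = ResBlock(channels=64, emb_channels=256, dropout=0.0,
+#                    out_channels=128, use_scale_shift_norm=True)
+#   h = block(paddle.rand([2, 64, 32, 32]), paddle.rand([2, 256]))
+#   # h.shape == [2, 128, 32, 32]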
+
+
+class AttentionBlock(nn.Layer):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ b, c, *spatial = x.shape
+ x = paddle.reshape(x, [b, c, -1])
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return paddle.reshape(x + h, [b, c, *spatial])
+
+
+class QKVAttentionLegacy(nn.Layer):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = paddle.reshape(
+ qkv, [bs * self.n_heads, ch * 3, length]).split(3, axis=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = paddle.einsum(
+ "bct,bcs->bts", q * scale,
+ k * scale) # More stable with f16 than dividing afterwards
+ weight = paddle.cast(F.softmax(paddle.cast(weight, 'float32'), axis=-1),
+ weight.dtype)
+ a = paddle.einsum("bts,bcs->bct", weight, v)
+
+ return paddle.reshape(a, [bs, -1, length])
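+
+# Illustrative shape check (example values are arbitrary): with n_heads=4 and a
+# per-head width of ch=8, the packed qkv width is 3 * 4 * 8 = 96.
+#
+#   attn = QKVAttentionLegacy(n_heads=4)
+#   out = attn(paddle.rand([2, 96, 10]))  # [N, H*3*C, T] -> [N, H*C, T]
+#   # out.shape == [2, 32, 10]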
+
+
+class QKVAttention(nn.Layer):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, axis=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = paddle.einsum(
+ "bct,bcs->bts",
+ (q * scale).reshape([bs * self.n_heads, ch, length]),
+ (k * scale).reshape([bs * self.n_heads, ch, length]),
+ ) # More stable with f16 than dividing afterwards
+ weight = paddle.cast(F.softmax(paddle.cast(weight, 'float32'), axis=-1),
+ weight.dtype)
+ a = paddle.einsum("bts,bcs->bct", weight,
+ v.reshape([bs * self.n_heads, ch, length]))
+ return paddle.reshape(a, [bs, -1, length])
+
+
+class UNetModel(nn.Layer):
+ """
+ The full UNet model with attention and timestep embedding.
+
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.dtype = paddle.float16 if use_fp16 else paddle.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.Silu(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.LayerList([
+ TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3,
+ padding=1))
+ ])
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ ) if resblock_updown else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch)))
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.LayerList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ))
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ ) if resblock_updown else Upsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.Silu(),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+
+ def forward(self, x, timesteps, y=None):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+
+ hs = []
+ emb = self.time_embed(timestep_embedding(timesteps,
+ self.model_channels))
+ if self.num_classes is not None:
+ assert y.shape == [x.shape[0]]
+ emb = emb + self.label_emb(y)
+
+ h = paddle.cast(x, self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = paddle.concat([h, hs.pop()], axis=1)
+ h = module(h, emb)
+ return self.out(h)
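+
+# Usage sketch (hyperparameters are arbitrary examples): a small 64x64 RGB
+# diffusion UNet that predicts 6 output channels (mean and variance) and applies
+# attention at 8x downsampling.
+#
+#   unet = UNetModel(image_size=64, in_channels=3, model_channels=128,
+#                    out_channels=6, num_res_blocks=2,
+#                    attention_resolutions=(8, ), channel_mult=(1, 2, 3, 4),
+#                    num_head_channels=64, use_scale_shift_norm=True,
+#                    resblock_updown=True)
+#   eps = unet(paddle.rand([2, 3, 64, 64]), paddle.randint(0, 1000, [2]))
+#   # eps.shape == [2, 6, 64, 64]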
+
+
+class SuperResModel(UNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+
+ def __init__(self, image_size, in_channels, *args, **kwargs):
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
+ _, _, new_height, new_width = x.shape
+ upsampled = F.interpolate(low_res, (new_height, new_width),
+ mode="bilinear")
+ x = paddle.concat([x, upsampled], axis=1)
+ return super().forward(x, timesteps, **kwargs)
diff --git a/paddlenlp/transformers/guided_diffusion_utils/utils.py b/paddlenlp/transformers/guided_diffusion_utils/utils.py
new file mode 100644
index 000000000000..e4c0478dbfc0
--- /dev/null
+++ b/paddlenlp/transformers/guided_diffusion_utils/utils.py
@@ -0,0 +1,442 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''
+This code is rewritten in Paddle, based on Jina-ai/discoart.
+https://github.com/jina-ai/discoart/blob/main/discoart/runner.py
+'''
+import gc
+import random
+
+import numpy as np
+import paddle
+import paddle.vision.transforms as T
+
+from PIL import Image
+from pathlib import Path
+from paddle.utils import try_import
+from .losses import range_loss, spherical_dist_loss, tv_loss
+from .make_cutouts import MakeCutoutsDango
+from .sec_diff import alpha_sigma_to_t
+from .transforms import Normalize
+from .perlin_noises import create_perlin_noise, regen_perlin
+from ..image_utils import load_image
+
+__all__ = ["DiscoDiffusionMixin"]
+
+
+def set_seed(seed):
+ np.random.seed(seed)
+ random.seed(seed)
+ paddle.seed(seed)
+
+
+class DiscoDiffusionMixin:
+
+ def disco_diffusion_generate(self,
+ target_text_embeds,
+ init_image=None,
+ output_dir='outputs/',
+ width_height=[1280, 768],
+ skip_steps=0,
+ cut_ic_pow=1,
+ init_scale=1000,
+ clip_guidance_scale=5000,
+ tv_scale=0,
+ range_scale=0,
+ sat_scale=0,
+ cutn_batches=4,
+ perlin_init=False,
+ perlin_mode='mixed',
+ seed=None,
+ eta=0.8,
+ clamp_grad=True,
+ clamp_max=0.05,
+ cut_overview='[12]*400+[4]*600',
+ cut_innercut='[4]*400+[12]*600',
+ cut_icgray_p='[0.2]*400+[0]*600',
+ save_rate=10,
+ n_batches=1,
+ batch_name="",
+ use_secondary_model=True,
+ randomize_class=True,
+ clip_denoised=False,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258,
+ 0.27577711]):
+ r'''
+ The DiffusionMixin diffusion_generate method.
+
+ Args:
+ init_image (Path, optional):
+ By default, diffusion starts from pure noise. If an init_image is provided, diffusion will
+ replace the noise with the init_image as its starting state. To use an init_image, pass the
+ local path or URL of the image here. If using an init_image, you may need to increase skip_steps
+ to ~ 50% of total steps to retain the character of the init. See skip_steps below for further discussion.
+ Default to `None`.
+ output_dir (Path, optional):
+ Output directory.
+ Default to `outputs/`.
+ width_height (List[int, int], optional):
+ Desired final image size, in pixels. You can have a square, wide, or tall image, but each edge
+ length should be set to a multiple of 64px, and a minimum of 512px on the default CLIP model setting.
+ If you forget to use multiples of 64px in your dimensions, DD will adjust the dimensions of your
+ image to make it so.
+ Default to `[1280, 768]`.
+ skip_steps (int, optional):
+ Noise scheduling (denoise strength) starts very high and progressively
+ gets lower and lower as diffusion steps progress. The noise levels in the first few steps are very high,
+ so images change dramatically in early steps. As DD moves along the curve, noise levels (and thus the
+ amount an image changes per step) decline, and image coherence from one step to the next increases.
+ The first few steps of denoising are often so dramatic that some steps (maybe 10-15% of total) can be
+ skipped without affecting the final image. You can experiment with this as a way to cut render times.
+ If you skip too many steps, however, the remaining noise may not be high enough to generate new content,
+ and thus may not have time left to finish an image satisfactorily. Also, depending on your other settings,
+ you may need to skip steps to prevent CLIP from overshooting your goal, resulting in blown out colors
+ (hyper saturated, solid white, or solid black regions) or otherwise poor image quality. Consider that
+ the denoising process is at its strongest in the early steps, so skipping steps can sometimes mitigate
+ other problems. Lastly, if using an init_image, you will need to skip ~50% of the diffusion steps to retain
+ the shapes in the original init image. However, if you're using an init_image, you can also adjust
+ skip_steps up or down for creative reasons. With low skip_steps you can get a result "inspired by"
+ the init_image which will retain the colors and rough layout and shapes but look quite different.
+ With high skip_steps you can preserve most of the init_image contents and just do fine-tuning of the texture.
+ Default to `0`.
+ steps:
+ When creating an image, the denoising curve is subdivided into steps for processing. Each step (or iteration)
+ involves the AI looking at subsets of the image called 'cuts' and calculating the 'direction' the image
+ should be guided to be more like the prompt. Then it adjusts the image with the help of the diffusion denoiser,
+ and moves to the next step. Increasing steps will provide more opportunities for the AI to adjust the image,
+ and each adjustment will be smaller, and thus will yield a more precise, detailed image. Increasing steps
+ comes at the expense of longer render times. Also, while increasing steps should generally increase image
+ quality, there is a diminishing return on additional steps beyond 250 - 500 steps. However, some intricate
+ images can take 1000, 2000, or more steps. It is really up to the user. Just know that the render time is
+ directly related to the number of steps, and many other parameters have a major impact on image quality, without
+ costing additional time.
+ cut_ic_pow (int, optional):
+ This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and
+ therefore the cuts themselves will be smaller and provide finer details. If you have too many or too-small
+ inner cuts, you may lose overall image coherency and/or it may cause an undesirable 'mosaic' effect.
+ Low cut_ic_pow values will allow the inner cuts to be larger, helping image coherency while still helping
+ with some details.
+ Default to `1`.
+ init_scale (int, optional):
+ This controls how strongly CLIP will try to match the init_image provided. This is balanced against the
+ clip_guidance_scale (CGS) above. Too much init scale, and the image won't change much during diffusion.
+ Too much CGS and the init image will be lost.
+ Default to `1000`.
+ clip_guidance_scale (int, optional):
+ CGS is one of the most important parameters you will use. It tells DD how strongly you want CLIP to move
+ toward your prompt each timestep. Higher is generally better, but if CGS is too strong it will overshoot
+ the goal and distort the image. So a happy medium is needed, and it takes experience to learn how to adjust
+ CGS. Note that this parameter generally scales with image dimensions. In other words, if you increase your
+ total dimensions by 50% (e.g. a change from 512 x 512 to 512 x 768), then to maintain the same effect on the
+ image, you'd want to increase clip_guidance_scale from 5000 to 7500. Of the basic settings, clip_guidance_scale,
+ steps and skip_steps are the most important contributors to image quality, so learn them well.
+ Default to `5000`.
+ tv_scale (int, optional):
+ Total variance denoising. Optional, set to zero to turn off. Controls smoothness of final output. If used,
+ tv_scale will try to smooth out your final image to reduce overall noise. If your image is too 'crunchy',
+ increase tv_scale. TV denoising is good at preserving edges while smoothing away noise in flat regions.
+ See https://en.wikipedia.org/wiki/Total_variation_denoising
+ Default to `0`.
+ range_scale (int, optional):
+ Optional, set to zero to turn off. Used for adjustment of color contrast. Lower range_scale will increase
+ contrast. Very low numbers create a reduced color palette, resulting in more vibrant or poster-like images.
+ Higher range_scale will reduce contrast, for more muted images.
+ Default to `0`.
+ sat_scale (int, optional):
+ Saturation scale. Optional, set to zero to turn off. If used, sat_scale will help mitigate oversaturation.
+ If your image is too saturated, increase sat_scale to reduce the saturation.
+ Default to `0`.
+ cutn_batches (int, optional):
+ Each iteration, the AI cuts the image into smaller pieces known as cuts, and compares each cut to the prompt
+ to decide how to guide the next diffusion step. More cuts can generally lead to better images, since DD has
+ more chances to fine-tune the image precision in each timestep. Additional cuts are memory intensive, however,
+ and if DD tries to evaluate too many cuts at once, it can run out of memory. You can use cutn_batches to increase
+ cuts per timestep without increasing memory usage. At the default settings, DD is scheduled to do 16 cuts per
+ timestep. If cutn_batches is set to 1, there will indeed only be 16 cuts total per timestep. However, if
+ cutn_batches is increased to 4, DD will do 64 cuts total in each timestep, divided into 4 sequential batches
+ of 16 cuts each. Because the cuts are being evaluated only 16 at a time, DD uses the memory required for only 16 cuts,
+ but gives you the quality benefit of 64 cuts. The tradeoff, of course, is that this will take ~4 times as long to
+ render each image. So, (scheduled cuts) x (cutn_batches) = (total cuts per timestep). Increasing cutn_batches will
+ increase render times, however, as the work is being done sequentially. DD's default cut schedule is a good place
+ to start, but the cut schedule can be adjusted in the Cutn Scheduling section, explained below.
+ Default to `4`.
+ perlin_init (bool, optional):
+ Normally, DD will use an image filled with random noise as a starting point for the diffusion curve.
+ If perlin_init is selected, DD will instead use a Perlin noise model as an initial state. Perlin has very
+ interesting characteristics, distinct from random noise, so it's worth experimenting with this for your projects.
+ Beyond perlin, you can, of course, generate your own noise images (such as with GIMP, etc) and use them as an
+ init_image (without skipping steps). Choosing perlin_init does not affect the actual diffusion process, just the
+ starting point for the diffusion. Please note that selecting a perlin_init will replace and override any init_image
+ you may have specified. Further, because the 2D, 3D and video animation systems all rely on the init_image system,
+ if you enable Perlin while using animation modes, the perlin_init will jump in front of any previous image or video
+ input, and DD will NOT give you the expected sequence of coherent images. All of that said, using Perlin and
+ animation modes together do make a very colorful rainbow effect, which can be used creatively.
+ Default to `False`.
+ perlin_mode (str, optional):
+ sets type of Perlin noise: colored, gray, or a mix of both, giving you additional options for noise types. Experiment
+ to see what these do in your projects.
+ Default to `mixed`.
+ seed (int, optional):
+ Deep in the diffusion code, there is a random number seed which is used as the basis for determining the initial
+ state of the diffusion. By default, this is random, but you can also specify your own seed. This is useful if you like a
+ particular result and would like to run more iterations that will be similar. After each run, the actual seed value used will be
+ reported in the parameters report, and can be reused if desired by entering seed # here. If a specific numerical seed is used
+ repeatedly, the resulting images will be quite similar but not identical.
+ Default to `None`.
+ eta (float, optional):
+ Eta (greek letter η) is a diffusion model variable that mixes in a random amount of scaled noise into each timestep.
+ 0 is no noise, 1.0 is more noise. As with most DD parameters, you can go below zero for eta, but it may give you
+ unpredictable results. The steps parameter has a close relationship with the eta parameter. If you set eta to 0,
+ then you can get decent output with only 50-75 steps. Setting eta to 1.0 favors higher step counts, ideally around
+ 250 and up. eta has a subtle, unpredictable effect on image, so you'll need to experiment to see how this affects your projects.
+ Default to `0.8`.
+ clamp_grad (bool, optional):
+ As I understand it, clamp_grad is an internal limiter that stops DD from producing extreme results. Try your images with and without
+ clamp_grad. If the image changes drastically with clamp_grad turned off, it probably means your clip_guidance_scale is too high and
+ should be reduced.
+ Default to `True`.
+ clamp_max (float, optional):
+ Sets the value of the clamp_grad limitation. Default is 0.05, providing for smoother, more muted coloration in images, but setting
+ higher values (0.15-0.3) can provide interesting contrast and vibrancy.
+ Default to `0.05`.
+ cut_overview (str, optional):
+ The schedule of overview cuts.
+ Default to `'[12]*400+[4]*600'`.
+ cut_innercut (str, optional):
+ The schedule of inner cuts.
+ Default to `'[4]*400+[12]*600'`.
+ cut_icgray_p (str, optional):
+ The schedule for the fraction of inner cuts that are converted to grayscale before being scored by CLIP.
+ Graying out a portion of the inner cuts early on encourages CLIP to focus on structure and composition
+ rather than color.
+ Default to `'[0.2]*400+[0]*600'`.
+ save_rate (int, optional):
+ During a diffusion run, you can monitor the progress of each image being created with this variable. If save_rate is set
+ to 50, DD will save the in-progress image every 50 timesteps. Setting this to a lower value, like 5 or 10, is a good way
+ to get an early peek at where your image is heading. If you don't like the progression, just interrupt execution, change some
+ settings, and re-run. If you are planning a long, unmonitored batch, it's better to set save_rate equal to steps, because
+ saving interim images does slow the run down slightly.
+ Default to `10`.
+ n_batches (int, optional):
+ This variable sets the number of still images you want DD to create. If you are using an animation mode,
+ DD will ignore n_batches and create a single set of animated frames based on the animation settings.
+ Default to `1`.
+ batch_name (str, optional):
+ The name of the batch; output files are named "progress-[batch_name]-[seed]-[batch index]-[step].png".
+ To avoid your artworks being overwritten by other users, please use a unique name.
+ Default to `''`.
+ use_secondary_model (bool, optional):
+ Whether or not use secondary model.
+ Default to `True`.
+ randomize_class (bool, optional):
+ Random class.
+ Default to `True`.
+ clip_denoised (bool, optional):
+ Clip denoised.
+ Default to `False`.
+ '''
+ output_dir = Path(output_dir)
+ output_dir.mkdir(exist_ok=True, parents=True)
+ batch_size = 1
+ normalize = Normalize(
+ mean=image_mean,
+ std=image_std,
+ )
+ side_x = (width_height[0] // 64) * 64
+ side_y = (width_height[1] // 64) * 64
+ cut_overview = eval(cut_overview)
+ cut_innercut = eval(cut_innercut)
+ cut_icgray_p = eval(cut_icgray_p)
+
+ seed = seed or random.randint(0, 2**32)
+ set_seed(seed)
+
+ init = None
+ if init_image:
+ d = load_image(init_image)
+ init = T.to_tensor(d).unsqueeze(0) * 2 - 1
+
+ if perlin_init:
+ if perlin_mode == 'color':
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)],
+ 1, 1, False, side_y, side_x)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)],
+ 4, 4, False, side_y, side_x)
+ elif perlin_mode == 'gray':
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)],
+ 1, 1, True, side_y, side_x)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)],
+ 4, 4, True, side_y, side_x)
+ else:
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)],
+ 1, 1, False, side_y, side_x)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)],
+ 4, 4, True, side_y, side_x)
+ init = (T.to_tensor(init).add(T.to_tensor(init2)).divide(
+ paddle.to_tensor(2.0)).unsqueeze(0) * 2 - 1)
+ del init2
+
+ if init is not None and init_scale:
+ lpips = try_import("paddle_lpips")
+ lpips_model = lpips.LPIPS(net='vgg')
+ lpips_model.eval()
+ for parameter in lpips_model.parameters():
+ parameter.stop_gradient = True
+
+ cur_t = None
+
+ def cond_fn(x, t, y=None):
+ x_is_NaN = False
+ n = x.shape[0]
+ x = paddle.to_tensor(x.detach(), dtype='float32')
+ x.stop_gradient = False
+ if use_secondary_model:
+ alpha = paddle.to_tensor(
+ self.diffusion.sqrt_alphas_cumprod[cur_t], dtype='float32')
+ sigma = paddle.to_tensor(
+ self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t],
+ dtype='float32')
+ cosine_t = alpha_sigma_to_t(alpha, sigma)
+ cosine_t = paddle.tile(
+ paddle.to_tensor(cosine_t.detach().cpu().numpy()), [n])
+ cosine_t.stop_gradient = False
+ out = self.secondary_model(x, cosine_t).pred
+ fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+ x_in_d = out * fac + x * (1 - fac)
+ x_in = x_in_d.detach()
+ x_in.stop_gradient = False
+ x_in_grad = paddle.zeros_like(x_in, dtype='float32')
+ else:
+ t = paddle.ones([n], dtype='int64') * cur_t
+ out = self.diffusion.p_mean_variance(self.unet_model,
+ x,
+ t,
+ clip_denoised=False,
+ model_kwargs={'y': y})
+ fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+ x_in_d = out['pred_xstart'].astype("float32") * fac + x * (1 -
+ fac)
+ x_in = x_in_d.detach()
+ x_in.stop_gradient = False
+ x_in_grad = paddle.zeros_like(x_in, dtype='float32')
+
+ for _ in range(cutn_batches):
+ t_int = (
+ int(t.item()) + 1
+ ) # errors on last step without +1, need to find source
+ # when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'
+ try:
+ input_resolution = self.vision_model.input_resolution
+ except AttributeError:
+ input_resolution = 224
+
+ cuts = MakeCutoutsDango(
+ input_resolution,
+ Overview=cut_overview[1000 - t_int],
+ InnerCrop=cut_innercut[1000 - t_int],
+ IC_Size_Pow=cut_ic_pow,
+ IC_Grey_P=cut_icgray_p[1000 - t_int],
+ )
+ clip_in = normalize(
+ cuts(
+ x_in.add(paddle.to_tensor(1.0)).divide(
+ paddle.to_tensor(2.0))))
+ image_embeds = self.get_image_features(clip_in)
+
+ dists = spherical_dist_loss(
+ image_embeds.unsqueeze(1),
+ target_text_embeds.unsqueeze(0),
+ )
+
+ dists = dists.reshape([
+ cut_overview[1000 - t_int] + cut_innercut[1000 - t_int],
+ n,
+ -1,
+ ])
+ losses = dists.sum(2).mean(0)
+ x_in_grad += (
+ paddle.grad(losses.sum() * clip_guidance_scale, x_in)[0] /
+ cutn_batches)
+ tv_losses = tv_loss(x_in)
+ range_losses = range_loss(x_in)
+ sat_losses = paddle.abs(x_in - x_in.clip(min=-1, max=1)).mean()
+ loss = (tv_losses.sum() * tv_scale +
+ range_losses.sum() * range_scale +
+ sat_losses.sum() * sat_scale)
+ if init is not None and init_scale:
+ init_losses = lpips_model(x_in, init)
+ loss = loss + init_losses.sum() * init_scale
+ x_in_grad += paddle.grad(loss, x_in)[0]
+ if not paddle.isnan(x_in_grad).any():
+ grad = -paddle.grad(x_in_d, x, x_in_grad)[0]
+ else:
+ x_is_NaN = True
+ grad = paddle.zeros_like(x)
+ if clamp_grad and not x_is_NaN:
+ magnitude = grad.square().mean().sqrt()
+ return (grad * magnitude.clip(max=clamp_max) / magnitude)
+ return grad
+
+ # we use ddim sample
+ sample_fn = self.diffusion.ddim_sample_loop_progressive
+
+ da_batches = []
+
+ # process output file name
+ output_filename_list = ["progress"]
+ if batch_name != "":
+ output_filename_list.append(batch_name)
+ if seed is not None:
+ output_filename_list.append(str(seed))
+ output_filename_prefix = "-".join(output_filename_list)
+
+ for _nb in range(n_batches):
+ gc.collect()
+ paddle.device.cuda.empty_cache()
+ cur_t = self.diffusion.num_timesteps - skip_steps - 1
+
+ if perlin_init:
+ init = regen_perlin(perlin_mode, side_y, side_x, batch_size)
+
+ samples = sample_fn(
+ self.unet_model,
+ (batch_size, 3, side_y, side_x),
+ clip_denoised=clip_denoised,
+ model_kwargs={},
+ cond_fn=cond_fn,
+ progress=True,
+ skip_timesteps=skip_steps,
+ init_image=init,
+ randomize_class=randomize_class,
+ eta=eta,
+ )
+
+ for j, sample in enumerate(samples):
+ cur_t -= 1
+ if j % save_rate == 0 or cur_t == -1:
+ for b, image in enumerate(sample['pred_xstart']):
+ image = (((image + 1) / 2).clip(
+ 0, 1).squeeze().transpose([1, 2, 0]).numpy() *
+ 255).astype("uint8")
+ image = Image.fromarray(image)
+ image.save(output_dir /
+ f'{output_filename_prefix}-{_nb}-{j}.png')
+ if cur_t == -1:
+ da_batches.append(image)
+
+ return da_batches
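+
+# Usage sketch: the mixin assumes the host model provides `self.unet_model`,
+# `self.diffusion`, `self.secondary_model`, `self.vision_model` and
+# `self.get_image_features`. `SomeDiscoDiffusionModel` below is a hypothetical
+# placeholder for such a model, not a class defined in this file.
+#
+#   model = SomeDiscoDiffusionModel.from_pretrained(model_name)
+#   # target_text_embeds: [num_prompts, dim] text features from the paired text encoder
+#   images = model.disco_diffusion_generate(target_text_embeds,
+#                                           width_height=[768, 768],
+#                                           n_batches=1, seed=42)
+#   images[0].save("result.png")  # returned entries are PIL.Image.Image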
diff --git a/paddlenlp/transformers/image_utils.py b/paddlenlp/transformers/image_utils.py
new file mode 100644
index 000000000000..4eb66338ef51
--- /dev/null
+++ b/paddlenlp/transformers/image_utils.py
@@ -0,0 +1,404 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import List, Union
+import paddle
+import numpy as np
+from PIL import Image
+import PIL.Image
+import PIL.ImageOps
+
+import requests
+
+IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
+IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
+IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
+
+ImageInput = Union[PIL.Image.Image, np.ndarray, "paddle.Tensor",
+ List[PIL.Image.Image], List[np.ndarray],
+ List["paddle.Tensor"] # noqa
+ ]
+
+
+def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
+ """
+ Loads `image` to a PIL Image.
+ Args:
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ Returns:
+ `PIL.Image.Image`: A PIL Image.
+ """
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
+ # like http_huggingface_co.png
+ image = PIL.Image.open(requests.get(image, stream=True).raw)
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+ )
+ elif isinstance(image, PIL.Image.Image):
+ image = image
+ else:
+ raise ValueError(
+ "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
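+
+# A minimal usage sketch for `load_image` (illustrative only; the paths below are hypothetical):
+#
+#     img = load_image("docs/imgs/cat_demo.jpg")        # local file path
+#     img = load_image("https://example.com/cat.png")   # remote URL
+#     # both return an RGB `PIL.Image.Image` with EXIF orientation applied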
+
+
+class ImageFeatureExtractionMixin:
+ """
+    Mixin that contains utilities for preparing image features.
+ """
+
+ def _ensure_format_supported(self, image):
+ if not isinstance(
+ image,
+ (PIL.Image.Image, np.ndarray)) and not paddle.is_tensor(image):
+ raise ValueError(
+ f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
+ "`paddle.Tensor` are.")
+
+ def to_pil_image(self, image, rescale=None):
+ """
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
+ needed.
+ Args:
+ image (`PIL.Image.Image` or `numpy.ndarray` or `paddle.Tensor`):
+ The image to convert to the PIL Image format.
+ rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
+ default to `True` if the image type is a floating type, `False` otherwise.
+ """
+ self._ensure_format_supported(image)
+
+ if paddle.is_tensor(image):
+ image = image.numpy()
+
+ if isinstance(image, np.ndarray):
+ if rescale is None:
+                # rescale defaults to True when the array is of a floating type.
+                rescale = isinstance(image.flat[0], np.floating)
+            # If the channel has been moved to the first dim, we put it back at the end.
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
+ image = image.transpose([1, 2, 0])
+ if rescale:
+ image = image * 255
+ image = image.astype(np.uint8)
+ return PIL.Image.fromarray(image)
+ return image
+
+ def convert_rgb(self, image):
+ """
+ Converts `PIL.Image.Image` to RGB format.
+ Args:
+ image (`PIL.Image.Image`):
+ The image to convert.
+ """
+ self._ensure_format_supported(image)
+ if not isinstance(image, PIL.Image.Image):
+ return image
+
+ return image.convert("RGB")
+
+ def to_numpy_array(self, image, rescale=None, channel_first=True):
+ """
+ Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
+ dimension.
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`):
+ The image to convert to a NumPy array.
+ rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
+ default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
+ channel_first (`bool`, *optional*, defaults to `True`):
+ Whether or not to permute the dimensions of the image to put the channel dimension first.
+ """
+ self._ensure_format_supported(image)
+
+ if isinstance(image, PIL.Image.Image):
+ image = np.array(image)
+
+ if paddle.is_tensor(image):
+ image = image.numpy()
+
+ if rescale is None:
+ rescale = isinstance(image.flat[0], np.integer)
+
+ if rescale:
+ image = image.astype(np.float32) / 255.0
+
+ if channel_first and image.ndim == 3:
+ image = image.transpose([2, 0, 1])
+
+ return image
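+
+    # Example of the conversion above (illustrative shapes): a 32x32 RGB PIL image becomes a
+    # float32 array of shape (3, 32, 32) with values in [0, 1] when the defaults are kept;
+    # passing `channel_first=False` keeps the (32, 32, 3) layout instead.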
+
+ def expand_dims(self, image):
+ """
+ Expands 2-dimensional `image` to 3 dimensions.
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`):
+ The image to expand.
+ """
+ self._ensure_format_supported(image)
+
+ # Do nothing if PIL image
+ if isinstance(image, PIL.Image.Image):
+ return image
+
+ if paddle.is_tensor(image):
+ image = image.unsqueeze(0)
+ else:
+ image = np.expand_dims(image, axis=0)
+ return image
+
+ def normalize(self, image, mean, std):
+ """
+ Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
+ if it's a PIL Image.
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`):
+ The image to normalize.
+ mean (`List[float]` or `np.ndarray` or `paddle.Tensor`):
+ The mean (per channel) to use for normalization.
+ std (`List[float]` or `np.ndarray` or `paddle.Tensor`):
+ The standard deviation (per channel) to use for normalization.
+ """
+ self._ensure_format_supported(image)
+
+ if isinstance(image, PIL.Image.Image):
+ image = self.to_numpy_array(image)
+
+ if isinstance(image, np.ndarray):
+ if not isinstance(mean, np.ndarray):
+ mean = np.array(mean).astype(image.dtype)
+ if not isinstance(std, np.ndarray):
+ std = np.array(std).astype(image.dtype)
+ elif paddle.is_tensor(image):
+ if not isinstance(mean, paddle.Tensor):
+ mean = paddle.to_tensor(mean).astype(image.dtype)
+ if not isinstance(std, paddle.Tensor):
+ std = paddle.to_tensor(std).astype(image.dtype)
+
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
+ return (image - mean[:, None, None]) / std[:, None, None]
+ else:
+ return (image - mean) / std
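+
+    # Illustrative sketch of `normalize`, using the ImageNet constants defined at the top of this
+    # module (`extractor` stands for any instance of a class that uses this mixin):
+    #
+    #     arr = extractor.to_numpy_array(pil_image)               # (3, H, W), floats in [0, 1]
+    #     arr = extractor.normalize(arr, IMAGENET_DEFAULT_MEAN,
+    #                               IMAGENET_DEFAULT_STD)         # per-channel (x - mean) / std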
+
+ def resize(self,
+ image,
+ size,
+ resample=Image.BILINEAR,
+ default_to_square=True,
+ max_size=None):
+ """
+ Resizes `image`. Enforces conversion of input to PIL.Image.
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`):
+ The image to resize.
+ size (`int` or `Tuple[int, int]`):
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
+ matched to this.
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
+ `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
+                this number, i.e., if height > width, then the image will be rescaled to (size * height / width, size).
+ resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+                The filter to use for resampling.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
+ square (`size`,`size`). If set to `False`, will replicate
+ [`paddle.vision.transforms.Resize`](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/vision/transforms/Resize_cn.html#resize)
+ with support for resizing only the smallest edge and providing an optional `max_size`.
+ max_size (`int`, *optional*, defaults to `None`):
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
+ greater than `max_size` after being resized according to `size`, then the image is resized again so
+                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e. the smaller
+ edge may be shorter than `size`. Only used if `default_to_square` is `False`.
+ Returns:
+ image: A resized `PIL.Image.Image`.
+ """
+ self._ensure_format_supported(image)
+
+ if not isinstance(image, PIL.Image.Image):
+ image = self.to_pil_image(image)
+
+ if isinstance(size, list):
+ size = tuple(size)
+
+ if isinstance(size, int) or len(size) == 1:
+ if default_to_square:
+ size = (size, size) if isinstance(size, int) else (size[0],
+ size[0])
+ else:
+ width, height = image.size
+ # specified size only for the smallest edge
+ short, long = (width, height) if width <= height else (height,
+ width)
+ requested_new_short = size if isinstance(size, int) else size[0]
+
+ if short == requested_new_short:
+ return image
+
+ new_short, new_long = requested_new_short, int(
+ requested_new_short * long / short)
+
+ if max_size is not None:
+ if max_size <= requested_new_short:
+ raise ValueError(
+ f"max_size = {max_size} must be strictly greater than the requested "
+ f"size for the smaller edge size = {size}")
+ if new_long > max_size:
+ new_short, new_long = int(max_size * new_short /
+ new_long), max_size
+
+ size = (new_short, new_long) if width <= height else (new_long,
+ new_short)
+
+ return image.resize(size, resample=resample)
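+
+    # Behaviour sketch for `resize` (illustrative numbers): with a 640x480 input,
+    # `size=224, default_to_square=True` returns a 224x224 image, while
+    # `size=224, default_to_square=False` matches the shorter edge and returns a 298x224
+    # (width x height) image; `max_size` would then cap the longer edge if set.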
+
+ def center_crop(self, image, size):
+ """
+ Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
+ size given, it will be padded (so the returned result has the size asked).
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
+ The image to resize.
+ size (`int` or `Tuple[int, int]`):
+ The size to which crop the image.
+ Returns:
+ new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `paddle.Tensor` of shape: (n_channels,
+ height, width).
+ """
+ self._ensure_format_supported(image)
+
+ if not isinstance(size, tuple):
+ size = (size, size)
+
+ # PIL Image.size is (width, height) but NumPy array and paddle Tensors have (height, width)
+ if paddle.is_tensor(image) or isinstance(image, np.ndarray):
+ if image.ndim == 2:
+ image = self.expand_dims(image)
+ image_shape = image.shape[1:] if image.shape[0] in [
+ 1, 3
+ ] else image.shape[:2]
+ else:
+ image_shape = (image.size[1], image.size[0])
+
+ top = (image_shape[0] - size[0]) // 2
+ bottom = top + size[
+ 0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
+ left = (image_shape[1] - size[1]) // 2
+ right = left + size[
+ 1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
+
+ # For PIL Images we have a method to crop directly.
+ if isinstance(image, PIL.Image.Image):
+ return image.crop((left, top, right, bottom))
+
+ # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
+ channel_first = True if image.shape[0] in [1, 3] else False
+
+ # Transpose (height, width, n_channels) format images
+ if not channel_first:
+ if isinstance(image, np.ndarray):
+ image = image.transpose([2, 0, 1])
+ if paddle.is_tensor(image):
+ image = image.transpose([2, 0, 1])
+
+ # Check if cropped area is within image boundaries
+ if top >= 0 and bottom <= image_shape[
+ 0] and left >= 0 and right <= image_shape[1]:
+ return image[..., top:bottom, left:right]
+
+        # Otherwise, we may need to pad if the image is too small.
+ new_shape = image.shape[:-2] + (max(
+ size[0], image_shape[0]), max(size[1], image_shape[1]))
+ if isinstance(image, np.ndarray):
+ new_image = np.zeros_like(image, shape=new_shape)
+ elif paddle.is_tensor(image):
+ new_image = paddle.zeros(new_shape, dtype=image.dtype)
+
+ top_pad = (new_shape[-2] - image_shape[0]) // 2
+ bottom_pad = top_pad + image_shape[0]
+ left_pad = (new_shape[-1] - image_shape[1]) // 2
+ right_pad = left_pad + image_shape[1]
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
+
+ top += top_pad
+ bottom += top_pad
+ left += left_pad
+ right += left_pad
+
+ new_image = new_image[...,
+ max(0, top):min(new_image.shape[-2], bottom),
+ max(0, left):min(new_image.shape[-1], right)]
+
+ return new_image
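+
+    # Illustrative example for `center_crop`: cropping a (3, 300, 400) array or tensor with
+    # `size=224` returns a (3, 224, 224) center crop; if either spatial side were smaller than
+    # 224, the input would first be zero-padded so the returned crop still has the requested size.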
+
+ def flip_channel_order(self, image):
+ """
+ Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
+ `image` to a NumPy array if it's a PIL Image.
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`):
+ The image whose color channels to flip. If `np.ndarray` or `paddle.Tensor`, the channel dimension should
+ be first.
+ """
+ self._ensure_format_supported(image)
+
+ if isinstance(image, PIL.Image.Image):
+ image = self.to_numpy_array(image)
+
+ return image[::-1, :, :]
+
+ def rotate(self,
+ image,
+ angle,
+ resample=Image.NEAREST,
+ expand=0,
+ center=None,
+ translate=None,
+ fillcolor=None):
+ """
+ Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
+        counterclockwise around its center.
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `paddle.Tensor`):
+ The image to rotate. If `np.ndarray` or `paddle.Tensor`, will be converted to `PIL.Image.Image` before
+ rotating.
+ Returns:
+ image: A rotated `PIL.Image.Image`.
+ """
+ self._ensure_format_supported(image)
+
+ if not isinstance(image, PIL.Image.Image):
+ image = self.to_pil_image(image)
+
+ return image.rotate(angle,
+ resample=resample,
+ expand=expand,
+ center=center,
+ translate=translate,
+ fillcolor=fillcolor)
diff --git a/paddlenlp/transformers/stable_diffusion_utils/__init__.py b/paddlenlp/transformers/stable_diffusion_utils/__init__.py
new file mode 100644
index 000000000000..e03311d4453f
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/__init__.py
@@ -0,0 +1,5 @@
+from .unet_2d_condition import UNet2DConditionModel
+from .vae import AutoencoderKL
+from .schedulers import (LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler,
+ SchedulerMixin)
+from .utils import StableDiffusionMixin
diff --git a/paddlenlp/transformers/stable_diffusion_utils/attention.py b/paddlenlp/transformers/stable_diffusion_utils/attention.py
new file mode 100644
index 000000000000..caad51663b36
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/attention.py
@@ -0,0 +1,299 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+def finfo(dtype):
+ if dtype == paddle.float32:
+ return np.finfo(np.float32)
+ if dtype == paddle.float16:
+ return np.finfo(np.float16)
+ if dtype == paddle.float64:
+ return np.finfo(np.float64)
+
+
+class AttentionBlock(nn.Layer):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    Uses three q, k, v linear layers to compute attention.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_head_channels=None,
+ num_groups=32,
+ rescale_output_factor=1.0,
+ eps=1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = (channels // num_head_channels
+ if num_head_channels is not None else 1)
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels,
+ num_groups=num_groups,
+ epsilon=eps)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+ self.proj_attn = nn.Linear(channels, channels)
+
+ def transpose_for_scores(self, projection: paddle.Tensor) -> paddle.Tensor:
+ new_projection_shape = projection.shape[:-1] + [self.num_heads, -1]
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.reshape(new_projection_shape).transpose(
+ [0, 2, 1, 3])
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.reshape([batch, channel, height * width
+ ]).transpose([0, 2, 1])
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ # transpose
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
+ # get scores
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+ attention_scores = paddle.matmul(query_states * scale,
+ key_states * scale,
+ transpose_y=True)
+ attention_probs = F.softmax(attention_scores.astype("float32"),
+ axis=-1).astype(attention_scores.dtype)
+
+ # compute attention output
+ context_states = paddle.matmul(attention_probs, value_states)
+
+ context_states = context_states.transpose([0, 2, 1, 3])
+ new_context_states_shape = context_states.shape[:-2] + [
+ self.channels,
+ ]
+ context_states = context_states.reshape(new_context_states_shape)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(context_states)
+ hidden_states = hidden_states.transpose([0, 2, 1]).reshape(
+ [batch, channel, height, width])
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
+
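+# Shape walk-through for AttentionBlock.forward (illustrative numbers): with channels=64 and
+# num_head_channels=32 the block uses 2 heads; an input of shape [2, 64, 4, 4] is flattened to
+# [2, 16, 64] (batch, positions, channels), self-attention is computed over the 16 spatial
+# positions, and the result is reshaped back to [2, 64, 4, 4] before the residual add and rescale.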
+
+class SpatialTransformer(nn.Layer):
+ """
+    Transformer block for image-like data. First, project the input (aka embedding) and reshape it to (b, t, d).
+    Then apply standard transformer blocks. Finally, reshape the result back to an image.
+ """
+
+ def __init__(self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None):
+ super().__init__()
+ self.n_heads = n_heads
+ self.d_head = d_head
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = nn.GroupNorm(num_groups=32,
+ num_channels=in_channels,
+ epsilon=1e-6)
+
+ self.proj_in = nn.Conv2D(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.LayerList([
+ BasicTransformerBlock(inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim) for d in range(depth)
+ ])
+
+ self.proj_out = nn.Conv2D(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = x.transpose([0, 2, 3, 1]).reshape([b, h * w, c])
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = x.reshape([b, h, w, c]).transpose([0, 3, 1, 2])
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class BasicTransformerBlock(nn.Layer):
+
+ def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None):
+ super().__init__()
+ self.attn1 = CrossAttention(query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout)
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+
+ def forward(self, x, context=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class CrossAttention(nn.Layer):
+
+ def __init__(self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = context_dim if context_dim is not None else query_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias_attr=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias_attr=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(
+ [batch_size, seq_len, head_size, dim // head_size])
+ tensor = tensor.transpose([0, 2, 1, 3]).reshape(
+ [batch_size * head_size, seq_len, dim // head_size])
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(
+ [batch_size // head_size, head_size, seq_len, dim])
+ tensor = tensor.transpose([0, 2, 1, 3]).reshape(
+ [batch_size // head_size, seq_len, dim * head_size])
+ return tensor
+
+ def forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, dim = x.shape
+
+ h = self.heads
+
+ q = self.to_q(x)
+ context = context if context is not None else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q = self.reshape_heads_to_batch_dim(q)
+ k = self.reshape_heads_to_batch_dim(k)
+ v = self.reshape_heads_to_batch_dim(v)
+
+ # sim = paddle.einsum("b i d, b j d -> b i j", q, k) * self.scale
+ sim = paddle.einsum("b i d, b j d -> b i j", q * self.scale, k)
+
+ if mask is not None:
+ mask = mask.reshape([batch_size, -1])
+ max_neg_value = -finfo(sim.dtype).max
+ mask = mask[:, None, :].expand([h, -1, -1]).astype(sim.dtype)
+ sim = sim * mask + (1 - mask) * max_neg_value
+
+        # compute the attention weights
+ attn = F.softmax(sim, axis=-1)
+
+ out = paddle.einsum("b i j, b j d -> b i d", attn, v)
+ out = self.reshape_batch_dim_to_heads(out)
+ return self.to_out(out)
+
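+# Head-reshaping sketch for CrossAttention (illustrative numbers): with heads=8 and dim_head=64,
+# a projected query of shape [2, 77, 512] is rearranged by reshape_heads_to_batch_dim into
+# [16, 77, 64] (heads folded into the batch), attention runs per head, and
+# reshape_batch_dim_to_heads restores [2, 77, 512] before the final output projection.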
+
+class FeedForward(nn.Layer):
+
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ project_in = GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# feedforward
+class GEGLU(nn.Layer):
+
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, axis=-1)
+ return x * F.gelu(gate)
diff --git a/paddlenlp/transformers/stable_diffusion_utils/embeddings.py b/paddlenlp/transformers/stable_diffusion_utils/embeddings.py
new file mode 100644
index 000000000000..f2767cb0306e
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/embeddings.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+
+
+def get_timestep_embedding(
+ timesteps,
+ embedding_dim,
+ flip_sin_to_cos=False,
+ downscale_freq_shift=1,
+ scale=1,
+ max_period=10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * paddle.arange(
+ start=0, end=half_dim, dtype="float32")
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = paddle.exp(exponent)
+ emb = timesteps[:, None].astype("float32") * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+        emb = paddle.concat([emb, paddle.zeros([emb.shape[0], 1])], axis=-1)
+ return emb
+
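+# Example of the sinusoidal embedding above (illustrative): for
+# `timesteps = paddle.to_tensor([0., 10., 999.])` and `embedding_dim=320`, the output has shape
+# [3, 320]; the first 160 columns are sines and the last 160 are cosines (order swapped when
+# `flip_sin_to_cos=True`), over 160 geometrically spaced frequencies controlled by `max_period`.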
+
+class TimestepEmbedding(nn.Layer):
+
+ def __init__(self, channel, time_embed_dim, act_fn="silu"):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
+ self.act = None
+ if act_fn == "silu":
+ self.act = nn.Silu()
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+ return sample
+
+
+class Timesteps(nn.Layer):
+
+ def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Layer):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size=256, scale=1.0):
+ super().__init__()
+ self.register_buffer("weight", paddle.randn((embedding_size, )) * scale)
+
+ # to delete later
+ self.register_buffer("W", paddle.randn((embedding_size, )) * scale)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ x = paddle.log(x)
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+ out = paddle.concat([paddle.sin(x_proj), paddle.cos(x_proj)], axis=-1)
+ return out
diff --git a/paddlenlp/transformers/stable_diffusion_utils/resnet.py b/paddlenlp/transformers/stable_diffusion_utils/resnet.py
new file mode 100644
index 000000000000..df1e1bb5f7df
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/resnet.py
@@ -0,0 +1,611 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def pad_new(x, pad, mode="constant", value=0):
+ new_pad = []
+ for _ in range(x.ndim * 2 - len(pad)):
+ new_pad.append(0)
+ ndim = list(range(x.ndim - 1, 0, -1))
+ axes_start = {}
+ for i, _pad in enumerate(pad):
+ if _pad < 0:
+ new_pad.append(0)
+ zhengshu, yushu = divmod(i, 2)
+ if yushu == 0:
+ axes_start[ndim[zhengshu]] = -_pad
+ else:
+ new_pad.append(_pad)
+
+ padded = paddle.nn.functional.pad(x, new_pad, mode=mode, value=value)
+ padded_shape = paddle.shape(padded)
+ axes = []
+ starts = []
+ ends = []
+ for k, v in axes_start.items():
+ axes.append(k)
+ starts.append(v)
+ ends.append(padded_shape[k])
+ assert v < padded_shape[k]
+
+ if axes:
+ return padded.slice(axes=axes, starts=starts, ends=ends)
+ else:
+ return padded
+
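+# `pad_new` mirrors `paddle.nn.functional.pad` but additionally supports negative pad values by
+# slicing the padded result. Small illustrative example: for an input of shape [1, 1, 4, 4],
+# `pad_new(x, (0, 1, 0, 1))` adds one zero column on the right and one zero row at the bottom,
+# giving shape [1, 1, 5, 5]; a negative entry would instead trim from the corresponding side.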
+
+class Upsample2D(nn.Layer):
+ """
+ An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the
+        inner-two dimensions.
+ """
+
+ def __init__(
+ self,
+ channels,
+ use_conv=False,
+ use_conv_transpose=False,
+ out_channels=None,
+ name="conv",
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.Conv2DTranspose(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2D(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ x = self.conv(x)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
+
+
+class Downsample2D(nn.Layer):
+ """
+ A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the
+        inner-two dimensions.
+ """
+
+ def __init__(self,
+ channels,
+ use_conv=False,
+ out_channels=None,
+ padding=1,
+ name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2D(self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2D(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ x = pad_new(x, pad, mode="constant", value=0)
+
+ assert x.shape[1] == self.channels
+ x = self.conv(x)
+
+ return x
+
+
+class FirUpsample2D(nn.Layer):
+
+ def __init__(self,
+ channels=None,
+ out_channels=None,
+ use_conv=False,
+ fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2D(channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+ Args:
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
+ order.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+ `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if k is None:
+ k = [1] * factor
+
+ # setup kernel
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+
+ k = k * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = w.shape[2]
+ convW = w.shape[3]
+ inC = w.shape[1]
+
+ p = (k.shape[0] - factor) - (convW - 1)
+
+            stride = (factor, factor)
+            # Determine data dimensions.
+ output_shape = (
+ (x.shape[2] - 1) * factor + convH,
+ (x.shape[3] - 1) * factor + convW,
+ )
+ output_padding = (
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ inC = w.shape[1]
+ num_groups = x.shape[1] // inC
+
+ # Transpose weights.
+ w = paddle.reshape(w, (num_groups, -1, inC, convH, convW))
+ w = w[..., ::-1, ::-1].transpose([0, 2, 1, 3, 4])
+ w = paddle.reshape(w, (num_groups * inC, -1, convH, convW))
+
+ x = F.conv2d_transpose(x,
+ w,
+ stride=stride,
+ output_padding=output_padding,
+ padding=0)
+
+ x = upfirdn2d_native(x,
+ paddle.to_tensor(k),
+ pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ else:
+ p = k.shape[0] - factor
+ x = upfirdn2d_native(
+ x,
+ paddle.to_tensor(k),
+ up=factor,
+ pad=((p + 1) // 2 + factor - 1, p // 2),
+ )
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+ h = h + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
+ else:
+ h = self._upsample_2d(x, k=self.fir_kernel, factor=2)
+
+ return h
+
+
+class FirDownsample2D(nn.Layer):
+
+ def __init__(self,
+ channels=None,
+ out_channels=None,
+ use_conv=False,
+ fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2D(channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+
+ Args:
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
+ order.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
+ filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
+ numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+ factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
+ Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+ datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if k is None:
+ k = [1] * factor
+
+ # setup kernel
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+
+ k = k * gain
+
+ if self.use_conv:
+ _, _, convH, convW = w.shape
+ p = (k.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d_native(x,
+ paddle.to_tensor(k),
+ pad=((p + 1) // 2, p // 2))
+ x = F.conv2d(x, w, stride=s, padding=0)
+ else:
+ p = k.shape[0] - factor
+ x = upfirdn2d_native(x,
+ paddle.to_tensor(k),
+ down=factor,
+ pad=((p + 1) // 2, p // 2))
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self._downsample_2d(x,
+ w=self.Conv2d_0.weight,
+ k=self.fir_kernel)
+ x = x + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
+ else:
+ x = self._downsample_2d(x, k=self.fir_kernel, factor=2)
+
+ return x
+
+
+class ResnetBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ kernel=None,
+ output_scale_factor=1.0,
+ use_nin_shortcut=None,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = nn.GroupNorm(num_groups=groups,
+ num_channels=in_channels,
+ epsilon=eps)
+
+ self.conv1 = nn.Conv2D(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ if temb_channels is not None:
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = nn.GroupNorm(num_groups=groups_out,
+ num_channels=out_channels,
+ epsilon=eps)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = nn.Conv2D(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.Silu()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate,
+ scale_factor=2.0,
+ mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels,
+ use_conv=False,
+ padding=1,
+ name="op")
+
+ self.use_nin_shortcut = (self.in_channels != self.out_channels if
+ use_nin_shortcut is None else use_nin_shortcut)
+
+ self.conv_shortcut = None
+ if self.use_nin_shortcut:
+ self.conv_shortcut = nn.Conv2D(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb, hey=False):
+ h = x
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ h = self.norm1(h.astype("float32")).astype(h.dtype)
+ h = self.nonlinearity(h)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ h = self.upsample(h)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ h = self.downsample(h)
+
+ h = self.conv1(h)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ h = h + temb
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ h = self.norm2(h.astype("float32")).astype(h.dtype)
+ h = self.nonlinearity(h)
+
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+
+ out = (x + h) / self.output_scale_factor
+
+ return out
+
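+# Minimal usage sketch for ResnetBlock2D (illustrative shapes only):
+#
+#     block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
+#     x = paddle.randn([2, 64, 32, 32])
+#     temb = paddle.randn([2, 512])
+#     out = block(x, temb)    # -> [2, 128, 32, 32]; the channel change is handled by a 1x1 shortcut conv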
+
+class Mish(nn.Layer):
+
+ def forward(self, x):
+ return x * F.tanh(F.softplus(x))
+
+
+def upsample_2d(x, k=None, factor=2, gain=1):
+ r"""Upsample2D a batch of 2D images with the given filter.
+
+ Args:
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
+ multiple of the upsampling factor.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if k is None:
+ k = [1] * factor
+
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+
+ k = k * (gain * (factor**2))
+ p = k.shape[0] - factor
+ return upfirdn2d_native(x,
+ paddle.to_tensor(k),
+ up=factor,
+ pad=((p + 1) // 2 + factor - 1, p // 2))
+
+
+def downsample_2d(x, k=None, factor=2, gain=1):
+ r"""Downsample2D a batch of 2D images with the given filter.
+
+ Args:
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+ shape is a multiple of the downsampling factor.
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if k is None:
+ k = [1] * factor
+
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+
+ k = k * gain
+ p = k.shape[0] - factor
+ return upfirdn2d_native(x,
+ paddle.to_tensor(k),
+ down=factor,
+ pad=((p + 1) // 2, p // 2))
+
+
+def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape([-1, in_h, in_w, 1])
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.reshape([-1, in_h, 1, in_w, 1, minor])
+ # TODO
+ out = pad_new(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.reshape([-1, in_h * up_y, in_w * up_x, minor])
+
+ out = pad_new(
+ out,
+ [0, 0,
+ max(pad_x0, 0),
+ max(pad_x1, 0),
+ max(pad_y0, 0),
+ max(pad_y1, 0)])
+ out = out[:,
+ max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.transpose([0, 3, 1, 2])
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = paddle.flip(kernel, [0, 1]).reshape([1, 1, kernel_h, kernel_w])
+ out = F.conv2d(out, w)
+ out = out.reshape([
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ ])
+ out = out.transpose([0, 2, 3, 1])
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.reshape([-1, channel, out_h, out_w])
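+
+
+# Note on upfirdn2d_native (upsample, apply FIR filter, downsample): the output spatial size
+# follows (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1. For example, upsampling a 16x16
+# feature map with up=2, a 4x4 kernel and pad=(2, 1), as `upsample_2d` does for factor=2, gives
+# (16 * 2 + 2 + 1 - 4) // 1 + 1 = 32 in each spatial dimension.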
diff --git a/paddlenlp/transformers/stable_diffusion_utils/schedulers.py b/paddlenlp/transformers/stable_diffusion_utils/schedulers.py
new file mode 100644
index 000000000000..9e419c8cc807
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/schedulers.py
@@ -0,0 +1,703 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+import math
+import numpy as np
+import paddle
+from scipy import integrate
+
+__all__ = [
+ "SchedulerMixin",
+ "DDIMScheduler",
+ "LMSDiscreteScheduler",
+ "PNDMScheduler",
+]
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of
+        (1-beta) up to that part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2)**2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
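+# Rough sanity check for betas_for_alpha_bar (illustrative): with num_diffusion_timesteps=1000 it
+# produces the "squaredcos_cap_v2" (Glide cosine) schedule used below; the betas start very small,
+# grow toward the end of the schedule, and are capped at max_beta=0.999.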
+
+class SchedulerMixin:
+
+ def set_format(self, tensor_format="pd"):
+ self.tensor_format = tensor_format
+ if tensor_format == "pd":
+ for key, value in vars(self).items():
+ if isinstance(value, np.ndarray):
+ setattr(self, key, paddle.to_tensor(value))
+
+ return self
+
+ def clip(self, tensor, min_value=None, max_value=None):
+ tensor_format = getattr(self, "tensor_format", "pd")
+
+ if tensor_format == "np":
+ return np.clip(tensor, min_value, max_value)
+ elif tensor_format == "pd":
+ return paddle.clip(tensor, min_value, max_value)
+
+ raise ValueError(
+ f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def log(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pd")
+
+ if tensor_format == "np":
+ return np.log(tensor)
+ elif tensor_format == "pd":
+ return paddle.log(tensor)
+
+ raise ValueError(
+ f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def match_shape(
+ self,
+ values: Union[np.ndarray, paddle.Tensor],
+ broadcast_array: Union[np.ndarray, paddle.Tensor],
+ ):
+ """
+ Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+ Args:
+ values: an array or tensor of values to extract.
+ broadcast_array: an array with a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ Returns:
+ a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+
+ tensor_format = getattr(self, "tensor_format", "pd")
+ values = values.flatten()
+
+ while len(values.shape) < len(broadcast_array.shape):
+ values = values[..., None]
+
+ return values
+
+ def norm(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pd")
+ if tensor_format == "np":
+ return np.linalg.norm(tensor)
+ elif tensor_format == "pd":
+ return paddle.norm(tensor.reshape([tensor.shape[0], -1]),
+ axis=-1).mean()
+
+ raise ValueError(
+ f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def randn_like(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pd")
+ if tensor_format == "np":
+            return np.random.randn(*np.shape(tensor))
+ elif tensor_format == "pd":
+ return paddle.randn(tensor.shape)
+
+ raise ValueError(
+ f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def zeros_like(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pd")
+ if tensor_format == "np":
+ return np.zeros_like(tensor)
+ elif tensor_format == "pd":
+ return paddle.zeros_like(tensor)
+
+ raise ValueError(
+ f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+
+class LMSDiscreteScheduler(SchedulerMixin):
+
+ def __init__(
+ self,
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule="linear",
+ trained_betas=None,
+ timestep_values=None,
+ tensor_format="pd",
+ ):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+ """
+ self.num_train_timesteps = num_train_timesteps
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.beta_schedule = beta_schedule
+ self.trained_betas = trained_betas
+ self.timestep_values = timestep_values
+ self.tensor_format = tensor_format
+
+ if beta_schedule == "linear":
+ self.betas = np.linspace(beta_start,
+ beta_end,
+ num_train_timesteps,
+ dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=np.float32,
+ )**2)
+ else:
+ raise NotImplementedError(
+ f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod)**0.5
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.derivatives = []
+ self.set_format(tensor_format=tensor_format)
+
+ def get_lms_coefficient(self, order, t, current_order):
+ """
+ Compute a linear multistep coefficient
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - self.sigmas[t - k]) / (
+ self.sigmas[t - current_order] - self.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative,
+ self.sigmas[t],
+ self.sigmas[t + 1],
+ epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(self, num_inference_steps):
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.linspace(self.num_train_timesteps - 1,
+ 0,
+ num_inference_steps,
+ dtype=float)
+
+ low_idx = np.floor(self.timesteps).astype(int)
+ high_idx = np.ceil(self.timesteps).astype(int)
+ frac = np.mod(self.timesteps, 1.0)
+ sigmas = np.array(
+ ((1 - self.alphas_cumprod) / self.alphas_cumprod)**0.5)
+ sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+ self.sigmas = np.concatenate([sigmas, [0.0]])
+
+ self.derivatives = []
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[paddle.Tensor, np.ndarray],
+ timestep: int,
+ sample: Union[paddle.Tensor, np.ndarray],
+ order: int = 4,
+ ):
+ sigma = self.sigmas[timestep]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ pred_original_sample = sample - sigma * model_output
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ self.derivatives.append(derivative)
+ if len(self.derivatives) > order:
+ self.derivatives.pop(0)
+
+ # 3. Compute linear multistep coefficients
+ order = min(timestep + 1, order)
+ lms_coeffs = [
+ self.get_lms_coefficient(order, timestep, curr_order)
+ for curr_order in range(order)
+ ]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(coeff * derivative
+ for coeff, derivative in zip(
+ lms_coeffs, reversed(self.derivatives)))
+
+ return {"prev_sample": prev_sample}
+
+ def add_noise(self, original_samples, noise, timesteps):
+ sigmas = self.match_shape(self.sigmas[timesteps], noise)
+ noisy_samples = original_samples + noise * sigmas
+ return noisy_samples
+
+ def __len__(self):
+ return self.num_train_timesteps
+
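+# Minimal denoising-loop sketch for LMSDiscreteScheduler (illustrative only; `unet` stands in for
+# any noise-prediction model with a compatible signature):
+#
+#     scheduler = LMSDiscreteScheduler(beta_schedule="scaled_linear")
+#     scheduler.set_timesteps(50)
+#     latents = paddle.randn([1, 4, 64, 64]) * scheduler.sigmas[0]
+#     for i, t in enumerate(scheduler.timesteps):
+#         latent_input = latents / ((scheduler.sigmas[i]**2 + 1)**0.5)
+#         noise_pred = unet(latent_input, t)
+#         latents = scheduler.step(noise_pred, i, latents)["prev_sample"]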
+
+class PNDMScheduler(SchedulerMixin):
+
+ def __init__(
+ self,
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule="linear",
+ skip_prk_steps=False,
+ tensor_format="pd",
+ ):
+ self.num_train_timesteps = num_train_timesteps
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.beta_schedule = beta_schedule
+ self.skip_prk_steps = skip_prk_steps
+ self.tensor_format = tensor_format
+
+ if beta_schedule == "linear":
+ self.betas = np.linspace(beta_start,
+ beta_end,
+ num_train_timesteps,
+ dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=np.float32,
+ )**2)
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(
+ f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.one = np.array(1.0)
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self._offset = 0
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps, offset=0):
+ self.num_inference_steps = num_inference_steps
+ self._timesteps = list(
+ range(
+ 0,
+ self.num_train_timesteps,
+ self.num_train_timesteps // num_inference_steps,
+ ))
+ self._offset = offset
+ self._timesteps = np.array([t + self._offset for t in self._timesteps])
+
+ if self.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([
+ self._timesteps[:-1], self._timesteps[-2:-1],
+ self._timesteps[-1:]
+ ])[::-1].copy()
+ else:
+ prk_timesteps = np.array(
+ self._timesteps[-self.pndm_order:]).repeat(2) + np.tile(
+ np.array([
+ 0, self.num_train_timesteps // num_inference_steps // 2
+ ]),
+ self.pndm_order,
+ )
+ self.prk_timesteps = (
+ prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][::-1].copy(
+ ) # we copy to avoid having negative strides which are not supported by paddle.to_tensor
+
+ self.timesteps = np.concatenate(
+ [self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+
+ self.ets = []
+ self.counter = 0
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[paddle.Tensor, np.ndarray],
+ timestep: int,
+ sample: Union[paddle.Tensor, np.ndarray],
+ ):
+ if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps:
+ return self.step_prk(model_output=model_output,
+ timestep=timestep,
+ sample=sample)
+ else:
+ return self.step_plms(model_output=model_output,
+ timestep=timestep,
+ sample=sample)
+
+ def step_prk(
+ self,
+ model_output: Union[paddle.Tensor, np.ndarray],
+ timestep: int,
+ sample: Union[paddle.Tensor, np.ndarray],
+ ):
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+ diff_to_prev = (0 if self.counter % 2 else self.num_train_timesteps //
+ self.num_inference_steps // 2)
+ prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep,
+ model_output)
+ self.counter += 1
+
+ return {"prev_sample": prev_sample}
+
+ def step_plms(
+ self,
+ model_output: Union[paddle.Tensor, np.ndarray],
+ timestep: int,
+ sample: Union[paddle.Tensor, np.ndarray],
+ ):
+ """
+ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+ times to approximate the solution.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+ if not self.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information.")
+
+ prev_timestep = max(
+ timestep - self.num_train_timesteps // self.num_inference_steps, 0)
+
+ if self.counter != 1:
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = (timestep +
+ self.num_train_timesteps // self.num_inference_steps)
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] +
+ 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] +
+ 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep,
+ model_output)
+ self.counter += 1
+
+ return {"prev_sample": prev_sample}
+
+ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>)
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 -
+ self._offset]
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # corresponds to (α_(t−δ) - α_t) divided by
+        # the denominator of x_t in formula (9), plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+        # sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t)**(0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev**(0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev)**(0.5)
+
+ # full formula (9)
+ prev_sample = (sample_coeff * sample -
+ (alpha_prod_t_prev - alpha_prod_t) * model_output /
+ model_output_denom_coeff)
+
+ return prev_sample
+
+ def add_noise(self, original_samples, noise, timesteps):
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod,
+ original_samples)
+
+ noisy_samples = (sqrt_alpha_prod * original_samples +
+ sqrt_one_minus_alpha_prod * noise)
+ return noisy_samples
+
+ def __len__(self):
+ return self.num_train_timesteps
+
+
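For orientation, here is a minimal denoising-loop sketch using the `PNDMScheduler` defined above; it is not part of the patch. `dummy_model`, the latent shape and the hyperparameters are placeholders (the constructor arguments mirror the ones used by `StableDiffusionMixin.set_scheduler` later in this patch); a real pipeline would call a UNet here.

```python
import paddle

def dummy_model(sample, t):
    # Placeholder for a real epsilon-prediction network.
    return paddle.zeros_like(sample)

# Hyperparameters mirror StableDiffusionMixin.set_scheduler("pndm") below.
scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear", skip_prk_steps=True)
scheduler.set_timesteps(50)

sample = paddle.randn([1, 4, 64, 64])  # illustrative latent shape
for t in scheduler.timesteps:
    noise_pred = dummy_model(sample, t)
    sample = scheduler.step(noise_pred, t, sample)["prev_sample"]
```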
+class DDIMScheduler(SchedulerMixin):
+
+ def __init__(
+ self,
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule="linear",
+ trained_betas=None,
+ timestep_values=None,
+ clip_sample=True,
+ set_alpha_to_one=True,
+ tensor_format="pd",
+ ):
+ self.num_train_timesteps = num_train_timesteps
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.beta_schedule = beta_schedule
+ self.trained_betas = trained_betas
+ self.timestep_values = timestep_values
+ self.clip_sample = clip_sample
+ self.set_alpha_to_one = set_alpha_to_one
+ self.tensor_format = tensor_format
+
+ if beta_schedule == "linear":
+ self.betas = np.linspace(beta_start,
+ beta_end,
+ num_train_timesteps,
+ dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=np.float32,
+ )**2)
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(
+                f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = (np.array(1.0) if set_alpha_to_one else
+ self.alphas_cumprod[0])
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (self.alphas_cumprod[prev_timestep] if
+ prev_timestep >= 0 else self.final_alpha_cumprod)
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev /
+ beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def set_timesteps(self, num_inference_steps, offset=0):
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0,
+ self.num_train_timesteps,
+ self.num_train_timesteps // self.num_inference_steps,
+ )[::-1].copy()
+ self.timesteps += offset
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(self,
+ model_output: Union[paddle.Tensor, np.ndarray],
+ timestep: int,
+ sample: Union[paddle.Tensor, np.ndarray],
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False):
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail before working through the following
+
+        # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = (timestep -
+ self.num_train_timesteps // self.num_inference_steps)
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (self.alphas_cumprod[prev_timestep] if
+ prev_timestep >= 0 else self.final_alpha_cumprod)
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t**
+ (0.5) * model_output) / alpha_prod_t**(0.5)
+
+ # 4. Clip "predicted x_0"
+ if self.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance**(0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t**
+ (0.5) * pred_original_sample) / beta_prod_t**(0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev -
+ std_dev_t**2)**(0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = (alpha_prod_t_prev**(0.5) * pred_original_sample +
+ pred_sample_direction)
+
+ if eta > 0:
+ noise = paddle.randn(model_output.shape)
+ variance = (self._get_variance(timestep, prev_timestep)**(0.5) *
+ eta * noise)
+
+ if not paddle.is_tensor(model_output):
+ variance = variance.numpy()
+
+ prev_sample = prev_sample + variance
+
+ return {"prev_sample": prev_sample}
+
+ def add_noise(self, original_samples, noise, timesteps):
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod,
+ original_samples)
+
+ noisy_samples = (sqrt_alpha_prod * original_samples +
+ sqrt_one_minus_alpha_prod * noise)
+ return noisy_samples
+
+ def __len__(self):
+ return self.num_train_timesteps
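A similarly hedged sketch of a DDIM sampling loop (not part of the patch); the constructor arguments mirror `StableDiffusionMixin.set_scheduler("ddim")` below, and `noise_pred` stands in for a real UNet forward pass. With `eta=0.0` the update is the deterministic rule of formula (12); `eta=1.0` adds DDPM-like noise via the variance of formula (16).

```python
import paddle

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False)
scheduler.set_timesteps(num_inference_steps=50, offset=1)

sample = paddle.randn([1, 4, 64, 64])  # illustrative latent shape
for t in scheduler.timesteps:
    noise_pred = paddle.zeros_like(sample)  # placeholder for a UNet call
    sample = scheduler.step(noise_pred, t, sample, eta=0.0)["prev_sample"]
```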
diff --git a/paddlenlp/transformers/stable_diffusion_utils/unet_2d_condition.py b/paddlenlp/transformers/stable_diffusion_utils/unet_2d_condition.py
new file mode 100644
index 000000000000..86e3a75601d8
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/unet_2d_condition.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Union
+
+import paddle
+import paddle.nn as nn
+
+from .embeddings import TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+
+
+class UNet2DConditionModel(nn.Layer):
+
+ def __init__(
+ self,
+ sample_size=64,
+ in_channels=4,
+ out_channels=4,
+ center_input_sample=False,
+ flip_sin_to_cos=True,
+ freq_shift=0,
+ down_block_types=(
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
+ block_out_channels=(320, 640, 1280, 1280),
+ layers_per_block=2,
+ downsample_padding=1,
+ mid_block_scale_factor=1,
+ act_fn="silu",
+ norm_num_groups=32,
+ norm_eps=1e-5,
+ cross_attention_dim=768,
+ attention_head_dim=8,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.sample_size = sample_size
+ self.center_input_sample = center_input_sample
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2D(in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
+ freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim,
+ time_embed_dim)
+
+ self.down_blocks = nn.LayerList([])
+ self.mid_block = None
+ self.up_blocks = nn.LayerList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(
+ i + 1,
+ len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0],
+ num_groups=norm_num_groups,
+ epsilon=norm_eps,
+ )
+ self.conv_act = nn.Silu()
+ self.conv_out = nn.Conv2D(block_out_channels[0],
+ out_channels,
+ 3,
+ padding=1)
+
+ def forward(
+ self,
+ sample: paddle.Tensor,
+ timestep: Union[paddle.Tensor, float, int],
+ encoder_hidden_states: paddle.Tensor,
+ ) -> Dict[str, paddle.Tensor]:
+
+ # 0. center input if necessary
+ if self.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not paddle.is_tensor(timesteps):
+ timesteps = paddle.to_tensor([timesteps], dtype="int64")
+ elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None]
+
+ # broadcast to batch dimension
+ timesteps = paddle.broadcast_to(timesteps, [sample.shape[0]])
+
+ t_emb = self.time_proj(timesteps)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample, )
+ for downsample_block in self.down_blocks:
+
+ if (hasattr(downsample_block, "attentions")
+ and downsample_block.attentions is not None):
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample,
+ temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for upsample_block in self.up_blocks:
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+ down_block_res_samples = down_block_res_samples[:-len(upsample_block
+ .resnets)]
+
+ if (hasattr(upsample_block, "attentions")
+ and upsample_block.attentions is not None):
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample = upsample_block(hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples)
+
+ # 6. post-process
+ # make sure hidden states is in float32
+ # when running in half-precision
+ sample = self.conv_norm_out(sample.astype("float32")).astype(
+ sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ output = {"sample": sample}
+
+ return output
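A small smoke-test sketch for the model above (not part of the patch). The reduced widths, the 77-token context length and the 32x32 latent are illustrative choices; the assumptions are only that channel counts stay divisible by `norm_num_groups` and that the spatial size survives the three downsamples. The stable-diffusion defaults are much larger.

```python
import paddle

unet = UNet2DConditionModel(
    block_out_channels=(32, 64, 64, 64),  # far smaller than the SD defaults
    layers_per_block=1,
    cross_attention_dim=64,
    attention_head_dim=8,
)
sample = paddle.randn([1, 4, 32, 32])              # noisy latents
encoder_hidden_states = paddle.randn([1, 77, 64])  # stand-in text embeddings
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states)
print(out["sample"].shape)  # [1, 4, 32, 32]
```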
diff --git a/paddlenlp/transformers/stable_diffusion_utils/unet_blocks.py b/paddlenlp/transformers/stable_diffusion_utils/unet_blocks.py
new file mode 100644
index 000000000000..e29b040b01e6
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/unet_blocks.py
@@ -0,0 +1,1513 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import (
+ Downsample2D,
+ FirDownsample2D,
+ FirUpsample2D,
+ ResnetBlock2D,
+ Upsample2D,
+)
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+ downsample_padding=None,
+):
+ down_block_type = (down_block_type[7:]
+ if down_block_type.startswith("UNetRes") else
+ down_block_type)
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+            raise ValueError(
+                "cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+        )
+    raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+):
+ up_block_type = (up_block_type[7:]
+ if up_block_type.startswith("UNetRes") else up_block_type)
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError(
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = (resnet_groups if resnet_groups is not None else min(
+ in_channels // 4, 32))
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ AttentionBlock(
+ in_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ ))
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.attention_type == "default":
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states, encoder_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = (resnet_groups if resnet_groups is not None else min(
+ in_channels // 4, 32))
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ SpatialTransformer(
+ in_channels,
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ ))
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ ))
+
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.LayerList([
+ Downsample2D(
+ in_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ])
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states, )
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states, )
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ ))
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.LayerList([
+ Downsample2D(
+ in_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ])
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+ output_states += (hidden_states, )
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states, )
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.resnets = nn.LayerList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.LayerList([
+ Downsample2D(
+ in_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ])
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states, )
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states, )
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.resnets = nn.LayerList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.LayerList([
+ Downsample2D(
+ in_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ])
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ ))
+
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.LayerList([
+ Downsample2D(
+ in_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ])
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.LayerList([])
+ self.resnets = nn.LayerList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ ))
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.LayerList(
+ [FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2D(3,
+ out_channels,
+ kernel_size=(1, 1),
+ stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states, )
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states, )
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.LayerList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.LayerList(
+ [FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2D(3,
+ out_channels,
+ kernel_size=(1, 1),
+ stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states, )
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states, )
+
+ return hidden_states, output_states, skip_sample
+
+
+class AttnUpBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_type="default",
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers -
+ 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ ))
+
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.LayerList([
+ Upsample2D(out_channels,
+ use_conv=True,
+ out_channels=out_channels)
+ ])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states],
+ axis=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers -
+ 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ ))
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.LayerList([
+ Upsample2D(out_channels,
+ use_conv=True,
+ out_channels=out_channels)
+ ])
+ else:
+ self.upsamplers = None
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states],
+ axis=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers -
+ 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.resnets = nn.LayerList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.LayerList([
+ Upsample2D(out_channels,
+ use_conv=True,
+ out_channels=out_channels)
+ ])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet in self.resnets:
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states],
+ axis=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.resnets = nn.LayerList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.LayerList([
+ Upsample2D(out_channels,
+ use_conv=True,
+ out_channels=out_channels)
+ ])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ ))
+
+ self.attentions = nn.LayerList(attentions)
+ self.resnets = nn.LayerList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.LayerList([
+ Upsample2D(out_channels,
+ use_conv=True,
+ out_channels=out_channels)
+ ])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ upsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.LayerList([])
+ self.resnets = nn.LayerList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers -
+ 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ ))
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2D(out_channels,
+ 3,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1))
+            self.skip_norm = nn.GroupNorm(
+                num_groups=min(out_channels // 4, 32),
+                num_channels=out_channels,
+                epsilon=resnet_eps,
+            )
+            self.act = nn.Silu()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states],
+ axis=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.LayerList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers -
+ 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4,
+ 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ ))
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2D(out_channels,
+ 3,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1))
+            self.skip_norm = nn.GroupNorm(
+                num_groups=min(out_channels // 4, 32),
+                num_channels=out_channels,
+                epsilon=resnet_eps,
+            )
+            self.act = nn.Silu()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states],
+ axis=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
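A minimal sketch of the block factory above (not part of the patch). All shapes and hyperparameters are illustrative; `temb` plays the role of the time embedding produced by `TimestepEmbedding` in the UNet, and `context` stands in for text-encoder hidden states.

```python
import paddle

block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=1,
    in_channels=32,
    out_channels=32,
    temb_channels=128,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    cross_attention_dim=64,
    downsample_padding=1,
)
hidden = paddle.randn([1, 32, 16, 16])
temb = paddle.randn([1, 128])        # time embedding
context = paddle.randn([1, 77, 64])  # cross-attention context
hidden, skips = block(hidden, temb=temb, encoder_hidden_states=context)
print(hidden.shape)  # [1, 32, 8, 8] after the trailing Downsample2D
```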
diff --git a/paddlenlp/transformers/stable_diffusion_utils/utils.py b/paddlenlp/transformers/stable_diffusion_utils/utils.py
new file mode 100644
index 000000000000..a11482596a03
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/utils.py
@@ -0,0 +1,486 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import random
+
+import numpy as np
+import paddle
+from PIL import Image
+from tqdm.auto import tqdm
+from typing import Optional
+
+from .schedulers import PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler
+from ..image_utils import load_image
+
+__all__ = ["StableDiffusionMixin"]
+
+
+class StableDiffusionMixin:
+
+ def set_scheduler(self, scheduler):
+ if isinstance(scheduler,
+ (PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler)):
+ self.scheduler = scheduler
+ elif isinstance(scheduler, str):
+ if scheduler == "pndm":
+ self.scheduler = PNDMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ skip_prk_steps=True,
+ )
+ elif scheduler == "ddim":
+ self.scheduler = DDIMScheduler(beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False)
+ elif scheduler == "k-lms":
+ self.scheduler = LMSDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear")
+ else:
+ raise ValueError(
+ 'scheduler must be in ["pndm", "ddim", "k-lms"].')
+        else:
+            raise ValueError(
+                "scheduler must be a PNDMScheduler, LMSDiscreteScheduler or "
+                "DDIMScheduler instance, or one of the strings 'pndm', 'ddim', 'k-lms'.")
+
+ @classmethod
+ def preprocess_image(cls, image):
+ image = load_image(image)
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32,
+ (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose([0, 3, 1, 2])
+ image = paddle.to_tensor(image)
+ return 2.0 * image - 1.0
+
+ @classmethod
+ def preprocess_mask(cls, mask):
+ mask = load_image(mask)
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32,
+ (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // 8, h // 8), resample=Image.NEAREST)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+        mask = mask[None].transpose([0, 1, 2, 3])  # add a batch axis; the transpose is an identity kept for parity with the reference code
+ mask = 1 - mask # repaint white, keep black
+ mask = paddle.to_tensor(mask)
+ return mask
+
+ @paddle.no_grad()
+ def stable_diffusion_text2image(
+ self,
+ input_ids,
+ seed: Optional[int] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ latents: Optional[paddle.Tensor] = None,
+ fp16: Optional[bool] = False,
+ ):
+ batch_size = input_ids.shape[0]
+ if height % 64 != 0 or width % 64 != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 64 but are {height} and {width}."
+ )
+
+ with paddle.amp.auto_cast(enable=fp16, level="O1"):
+ text_embeddings = self.clip.text_model(input_ids)[0]
+
+            # `guidance_scale` is defined analogously to the guidance weight `w` in equation (2)
+            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+            # corresponds to doing no classifier-free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ uncond_embeddings = self.clip.text_model(
+ self.input_ids_uncond.expand([batch_size, -1]))[0]
+ text_embeddings = paddle.concat(
+ [uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ latents_shape = [
+ batch_size, self.unet_model.in_channels, height // 8, width // 8
+ ]
+ if latents is None:
+ if seed is None:
+ seed = random.randint(0, 2**32)
+ paddle.seed(seed)
+ latents = paddle.randn(latents_shape)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(
+ f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
+ )
+
+ # set timesteps
+ accepts_offset = "offset" in set(
+ inspect.signature(
+ self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps,
+ **extra_set_kwargs)
+
+            # if we use LMSDiscreteScheduler, make sure the latents are multiplied by the initial sigma
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (paddle.concat([latents] * 2) if
+ do_classifier_free_guidance else latents)
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ latent_model_input = latent_model_input / (
+ (sigma**2 + 1)**0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet_model(
+ latent_model_input,
+ t,
+ encoder_hidden_states=text_embeddings)["sample"]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(
+ noise_pred, i, latents,
+ **extra_step_kwargs)["prev_sample"]
+ else:
+ latents = self.scheduler.step(
+ noise_pred, t, latents,
+ **extra_step_kwargs)["prev_sample"]
+
+ # scale and decode the image latents with vae
+ image = self.vae_model.decode(1 / 0.18215 * latents)
+ image = (image / 2 + 0.5).clip(0, 1)
+ image = image.transpose([0, 2, 3, 1]).cpu().numpy()
+ image = (image * 255).round().astype(np.uint8)
+ image = [Image.fromarray(img) for img in image]
+
+ return image
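+
+    # Hedged usage sketch (assumes `pipe` is a model that mixes in StableDiffusionMixin,
+    # provides `clip`, `unet_model`, `vae_model`, `input_ids_uncond`, and has had a
+    # scheduler set via `set_scheduler`; `input_ids` is an int64 tensor of tokenized
+    # prompts produced by a CLIP tokenizer):
+    #
+    #     images = pipe.stable_diffusion_text2image(
+    #         input_ids, seed=42, num_inference_steps=50, guidance_scale=7.5)
+    #     images[0].save("sample.png")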
+
+ @paddle.no_grad()
+ def stable_diffusion_image2image(
+ self,
+ input_ids,
+ init_image,
+ strength=0.8,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ eta=0.0,
+ seed=None,
+ fp16=False,
+ ):
+ batch_size = input_ids.shape[0]
+
+ if strength < 0 or strength > 1:
+ raise ValueError(
+ f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ with paddle.amp.auto_cast(enable=fp16, level="O1"):
+ # set timesteps
+ accepts_offset = "offset" in set(
+ inspect.signature(
+ self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps,
+ **extra_set_kwargs)
+
+ # encode the init image into latents and scale the latents
+ init_latents = self.vae_model.encode(init_image).sample()
+ init_latents = 0.18215 * init_latents
+
+ # prepare init_latents noise to latents
+ init_latents = paddle.concat([init_latents] * batch_size)
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ timesteps = paddle.to_tensor(
+ [num_inference_steps - init_timestep] * batch_size,
+ dtype="int64")
+ else:
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = paddle.to_tensor([timesteps] * batch_size,
+ dtype="int64")
+
+ # add noise to latents using the timesteps
+ if seed is None:
+ seed = random.randint(0, 2**32)
+ paddle.seed(seed)
+ noise = paddle.randn(init_latents.shape)
+ init_latents = self.scheduler.add_noise(init_latents, noise,
+ timesteps)
+
+ text_embeddings = self.clip.text_model(input_ids)[0]
+
+            # `guidance_scale` is defined analogously to the guidance weight `w` in equation (2)
+            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+            # corresponds to doing no classifier-free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ uncond_embeddings = self.clip.text_model(
+ self.input_ids_uncond.expand([batch_size, -1]))[0]
+ text_embeddings = paddle.concat(
+ [uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+ t_index = t_start + i
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (paddle.concat([latents] * 2) if
+ do_classifier_free_guidance else latents)
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[t_index]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / (
+ (sigma**2 + 1)**0.5)
+ latent_model_input = latent_model_input.astype(
+ paddle.get_default_dtype())
+ t = t.astype(paddle.get_default_dtype())
+
+ # predict the noise residual
+ noise_pred = self.unet_model(
+ latent_model_input,
+ t,
+ encoder_hidden_states=text_embeddings)["sample"]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(
+ noise_pred, t_index, latents,
+ **extra_step_kwargs)["prev_sample"]
+ else:
+ latents = self.scheduler.step(
+ noise_pred, t, latents,
+ **extra_step_kwargs)["prev_sample"]
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae_model.decode(
+ latents.astype(paddle.get_default_dtype()))
+ image = (image / 2 + 0.5).clip(0, 1)
+ image = image.transpose([0, 2, 3, 1]).cpu().numpy()
+ image = (image * 255).round().astype(np.uint8)
+ image = [Image.fromarray(img) for img in image]
+ return image
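+
+    # Hedged usage sketch (same assumptions as for text2image above; "init.png" is a
+    # local image path accepted by load_image):
+    #
+    #     init_image = pipe.preprocess_image("init.png")
+    #     images = pipe.stable_diffusion_image2image(
+    #         input_ids, init_image, strength=0.75, num_inference_steps=50)
+    #     images[0].save("variation.png")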
+
+ @paddle.no_grad()
+ def stable_diffusion_inpainting(
+ self,
+ input_ids,
+ init_image,
+ mask_image,
+ strength=0.8,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ eta=0.0,
+ seed=None,
+ fp16=False,
+ ):
+ batch_size = input_ids.shape[0]
+
+ if strength < 0 or strength > 1:
+ raise ValueError(
+ f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ with paddle.amp.auto_cast(enable=fp16, level="O1"):
+ # set timesteps
+ accepts_offset = "offset" in set(
+ inspect.signature(
+ self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps,
+ **extra_set_kwargs)
+
+ # encode the init image into latents and scale the latents
+ init_latents = self.vae_model.encode(init_image).sample()
+ init_latents = 0.18215 * init_latents
+
+ # prepare init_latents noise to latents
+ init_latents = paddle.concat([init_latents] * batch_size)
+ init_latents_orig = init_latents
+
+ mask = paddle.concat([mask_image] * batch_size)
+
+ # check sizes
+            if mask.shape != init_latents.shape:
+                raise ValueError(
+                    f"The mask and init_image should be the same size, but got "
+                    f"mask shape {mask.shape} and init_image latents shape {init_latents.shape}.")
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ timesteps = paddle.to_tensor(
+ [num_inference_steps - init_timestep] * batch_size,
+ dtype="int64")
+ else:
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = paddle.to_tensor([timesteps] * batch_size,
+ dtype="int64")
+
+ # add noise to latents using the timesteps
+ if seed is None:
+ seed = random.randint(0, 2**32)
+ paddle.seed(seed)
+ noise = paddle.randn(init_latents.shape)
+ init_latents = self.scheduler.add_noise(init_latents, noise,
+ timesteps)
+
+ # get prompt text embeddings
+ text_embeddings = self.clip.text_model(input_ids)[0]
+
+            # `guidance_scale` is defined analogously to the guidance weight `w` in equation (2)
+            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+            # corresponds to doing no classifier-free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ uncond_embeddings = self.clip.text_model(
+ self.input_ids_uncond.expand([batch_size, -1]))[0]
+ text_embeddings = paddle.concat(
+ [uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+ t_index = t_start + i
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (paddle.concat([latents] * 2) if
+ do_classifier_free_guidance else latents)
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[t_index]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / (
+ (sigma**2 + 1)**0.5)
+ latent_model_input = latent_model_input.astype(
+ paddle.get_default_dtype())
+ t = t.astype(paddle.get_default_dtype())
+
+ # predict the noise residual
+ noise_pred = self.unet_model(
+ latent_model_input,
+ t,
+ encoder_hidden_states=text_embeddings)["sample"]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(
+ noise_pred, t_index, latents,
+ **extra_step_kwargs)["prev_sample"]
+ # masking
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_orig, noise, t_index)
+ else:
+ latents = self.scheduler.step(
+ noise_pred, t, latents,
+ **extra_step_kwargs)["prev_sample"]
+
+ # masking
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_orig, noise, t)
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # scale and decode the image latents with vae
+ image = self.vae_model.decode(1 / 0.18215 * latents)
+ image = (image / 2 + 0.5).clip(0, 1)
+ image = image.transpose([0, 2, 3, 1]).cpu().numpy()
+ image = (image * 255).round().astype(np.uint8)
+ image = [Image.fromarray(img) for img in image]
+ return image
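+
+    # Hedged usage sketch (same assumptions as above; the white region of "mask.png"
+    # is the area to repaint):
+    #
+    #     init_image = pipe.preprocess_image("photo.png")
+    #     mask_image = pipe.preprocess_mask("mask.png")
+    #     images = pipe.stable_diffusion_inpainting(
+    #         input_ids, init_image, mask_image, strength=0.75)
+    #     images[0].save("inpainted.png")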
diff --git a/paddlenlp/transformers/stable_diffusion_utils/vae.py b/paddlenlp/transformers/stable_diffusion_utils/vae.py
new file mode 100644
index 000000000000..b18e5dc3a659
--- /dev/null
+++ b/paddlenlp/transformers/stable_diffusion_utils/vae.py
@@ -0,0 +1,519 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+class Encoder(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D", ),
+ block_out_channels=(64, ),
+ layers_per_block=2,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2D(in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ self.mid_block = None
+ self.down_blocks = nn.LayerList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1],
+ num_groups=num_groups_out,
+ epsilon=1e-6)
+ self.conv_act = nn.Silu()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2D(block_out_channels[-1],
+ conv_out_channels,
+ 3,
+ padding=1)
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
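+
+    # Shape sketch: with the AutoencoderKL defaults defined below
+    # (block_out_channels=(128, 256, 512, 512), latent_channels=4, double_z=True),
+    # a [1, 3, 512, 512] image is encoded to [1, 8, 64, 64] moments, i.e. the mean
+    # and log-variance of a 4-channel latent downsampled by a factor of 8.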
+
+
+class Decoder(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D", ),
+ block_out_channels=(64, ),
+ layers_per_block=2,
+ act_fn="silu",
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2D(in_channels,
+ block_out_channels[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ self.mid_block = None
+ self.up_blocks = nn.LayerList([])
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0],
+ num_groups=num_groups_out,
+ epsilon=1e-6)
+ self.conv_act = nn.Silu()
+ self.conv_out = nn.Conv2D(block_out_channels[0],
+ out_channels,
+ 3,
+ padding=1)
+
+ def forward(self, z):
+ sample = z
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Layer):
+ """
+    Improved version of VectorQuantizer that can be used as a drop-in replacement. It mostly avoids costly
+    matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+    # NOTE: due to a bug, the beta term was originally applied to the wrong term. For
+    # backwards compatibility we keep the buggy behaviour by default; pass
+    # legacy=False to use the corrected loss.
+ def __init__(
+ self,
+ n_e,
+ e_dim,
+ beta,
+ remap=None,
+ unknown_index="random",
+ sane_index_shape=False,
+ legacy=True,
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+        # initialize the codebook uniformly in [-1/n_e, 1/n_e] via a Paddle initializer
+        self.embedding = nn.Embedding(
+            self.n_e,
+            self.e_dim,
+            weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
+                low=-1.0 / self.n_e, high=1.0 / self.n_e)))
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", paddle.to_tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape([ishape[0], -1])
+ used = self.used
+ match = (inds[:, :, None] == used[None, None, ...]).astype("int64")
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = paddle.randint(0,
+ self.re_embed,
+ shape=new[unknown].shape)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape([ishape[0], -1])
+ used = self.used
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = paddle.gather(used[None, :][inds.shape[0] * [0], :],
+ inds,
+ axis=1)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.transpose([0, 2, 3, 1])
+ z_flattened = z.reshape([-1, self.e_dim])
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (paddle.sum(z_flattened**2, axis=1, keepdim=True) +
+ paddle.sum(self.embedding.weight**2, axis=1) - 2 *
+ paddle.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()))
+
+ min_encoding_indices = paddle.argmin(d, axis=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * paddle.mean((z_q.detach() - z)**2) + paddle.mean(
+ (z_q - z.detach())**2)
+ else:
+ loss = paddle.mean((z_q.detach() - z)**2) + self.beta * paddle.mean(
+ (z_q - z.detach())**2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.transpose([0, 3, 1, 2])
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(
+ [z.shape[0], -1]) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape([-1,
+ 1]) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ [z_q.shape[0], z_q.shape[2], z_q.shape[3]])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape([shape[0], -1]) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.flatten() # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.reshape(shape)
+ # reshape back to match original input shape
+ z_q = z_q.transpose([0, 3, 1, 2])
+
+ return z_q
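+
+    # Hedged sketch of a quantization round trip on a dummy latent:
+    #
+    #     quantizer = VectorQuantizer(n_e=256, e_dim=3, beta=0.25)
+    #     z = paddle.randn([1, 3, 32, 32])
+    #     z_q, loss, (_, _, indices) = quantizer(z)   # z_q has the same shape as z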
+
+
+class DiagonalGaussianDistribution(object):
+
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = paddle.chunk(parameters, 2, axis=1)
+ self.logvar = paddle.clip(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = paddle.exp(0.5 * self.logvar)
+ self.var = paddle.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = paddle.zeros_like(self.mean)
+
+ def sample(self, seed=None):
+ if seed is not None:
+ paddle.seed(seed)
+ x = self.mean + self.std * paddle.randn(self.mean.shape)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return paddle.to_tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * paddle.sum(
+ paddle.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ axis=[1, 2, 3],
+ )
+ else:
+ return 0.5 * paddle.sum(
+ paddle.pow(self.mean - other.mean, 2) / other.var +
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
+ axis=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return paddle.to_tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * paddle.sum(
+ logtwopi + self.logvar +
+ paddle.pow(sample - self.mean, 2) / self.var,
+ axis=dims,
+ )
+
+ def mode(self):
+ return self.mean
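+
+    # Hedged sketch: AutoencoderKL.encode below returns one of these distributions,
+    # which is typically consumed as
+    #
+    #     posterior = vae.encode(images)   # DiagonalGaussianDistribution
+    #     z = posterior.sample()           # reparameterized draw: mean + std * eps
+    #     kl = posterior.kl()              # closed-form KL against a standard normal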
+
+
+class VQModel(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D", ),
+ up_block_types=("UpDecoderBlock2D", ),
+ block_out_channels=(64, ),
+ layers_per_block=1,
+ act_fn="silu",
+ latent_channels=3,
+ sample_size=32,
+ num_vq_embeddings=256,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=False,
+ )
+
+ self.quant_conv = nn.Conv2D(latent_channels, latent_channels, 1)
+ self.quantize = VectorQuantizer(
+ num_vq_embeddings,
+ latent_channels,
+ beta=0.25,
+ remap=None,
+ sane_index_shape=False,
+ )
+ self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, h, force_not_quantize=False):
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def forward(self, sample):
+ x = sample
+ h = self.encode(x)
+ dec = self.decode(h)
+ return dec
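+
+    # Hedged sketch: VQ round trip on a dummy input (the stable-diffusion pipelines
+    # above use AutoencoderKL, not VQModel):
+    #
+    #     vq = VQModel()
+    #     rec = vq(paddle.randn([1, 3, 32, 32]))   # encode -> quantize -> decode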
+
+
+class AutoencoderKL(nn.Layer):
+
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ),
+ up_block_types=(
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ),
+ block_out_channels=(128, 256, 512, 512),
+ layers_per_block=2,
+ act_fn="silu",
+ latent_channels=4,
+ sample_size=512,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, sample, sample_posterior=False):
+ x = sample
+ posterior = self.encode(x)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec
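+
+    # Hedged sketch: encode an image to the scaled latent space used by the
+    # pipelines above, then decode it back:
+    #
+    #     vae = AutoencoderKL()
+    #     posterior = vae.encode(paddle.randn([1, 3, 512, 512]))
+    #     latents = 0.18215 * posterior.sample()
+    #     rec = vae.decode(latents / 0.18215)      # [1, 3, 512, 512]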