Bump diffusers to 0.12.1 #302

Merged 1 commit on Mar 31, 2023
3 changes: 1 addition & 2 deletions notebooks/stable_diffusion/image_to_image.ipynb
@@ -132,8 +132,7 @@
" text_encoder_ipu_config=text_encoder_ipu_config,\n",
" vae_ipu_config=vae_ipu_config,\n",
" safety_checker_ipu_config=safety_checker_ipu_config\n",
")\n",
"pipe.enable_attention_slicing()"
")"
]
},
{
3 changes: 1 addition & 2 deletions notebooks/stable_diffusion/inpainting.ipynb
@@ -129,8 +129,7 @@
" text_encoder_ipu_config=text_encoder_ipu_config,\n",
" vae_ipu_config=vae_ipu_config,\n",
" safety_checker_ipu_config=safety_checker_ipu_config\n",
")\n",
"pipe.enable_attention_slicing()"
")"
]
},
{
3 changes: 1 addition & 2 deletions notebooks/stable_diffusion/text_to_image.ipynb
@@ -130,8 +130,7 @@
" text_encoder_ipu_config=text_encoder_ipu_config,\n",
" vae_ipu_config=vae_ipu_config,\n",
" safety_checker_ipu_config=safety_checker_ipu_config\n",
")\n",
"pipe.enable_attention_slicing()"
")"
]
},
{
3 changes: 1 addition & 2 deletions notebooks/stable_diffusion/text_to_image_sd2.ipynb
@@ -131,8 +131,7 @@
" vae_ipu_config=vae_ipu_config,\n",
" safety_checker_ipu_config=safety_checker_ipu_config\n",
")\n",
"pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
"pipe.enable_attention_slicing()"
"pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)"
]
},
{
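All four Stable Diffusion notebooks receive the same one-line removal: the trailing `pipe.enable_attention_slicing()` call goes away because, after this bump, the IPU pipeline enables sliced attention by default when the UNet is parallelized (see `change_cross_attention_processor` and the `set_attention_slice` override in the pipeline changes below). A minimal usage sketch, assuming `pipe` and `prompt` stand for the objects defined in the preceding notebook cells:

    # Sketch only: `pipe` is the IPU pipeline built earlier in the notebook.
    # No pipe.enable_attention_slicing() call is needed; the UNet already uses
    # IPUSlicedAttnProcessor, and asking for a custom slice size only logs a warning.
    image = pipe(prompt).images[0]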
The following diff is for the IPU Stable Diffusion pipeline module (file path not shown):
@@ -19,7 +19,7 @@

import poptorch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.models.attention import CrossAttention
from diffusers.models.cross_attention import CrossAttention
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.vae import DecoderOutput
from optimum.utils import logging
@@ -35,40 +35,45 @@
logger = logging.get_logger(__name__)


class IPUCrossAttention(CrossAttention):
class IPUSlicedAttnProcessor:
"""
SlicedAttnProcessor but we slice across the query sequence length instead of across heads.
NB: this ignores the `slice_size` factor since we interpret it differently and use a value that is
derived from the sequence length based on an empirical attention matrix memory target.
"""

def __init__(self, attn_matrix_target_mem_mb: int):
if attn_matrix_target_mem_mb < 1:
raise ValueError(f"`attn_matrix_target_mem_mb` {attn_matrix_target_mem_mb} must be a positive integer.")

self._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb

@staticmethod
def _nearest_divisor(target, start, end):
for divisor in range(start, end + 1):
if target % divisor == 0:
return divisor
raise ValueError(f"No divisor found in range [{start}, {end}].")

def _attention(self, query, key, value, attention_mask):
"""Overriding this implementation as the `torch.baddbmm` op is not registered."""
attention_scores = torch.matmul(query, key.transpose(1, 2)) * self.scale
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape

if attention_mask is not None:
attention_scores = attention_scores + attention_mask
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

attention_probs = attention_scores.softmax(dim=-1)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)

hidden_states = torch.bmm(attention_probs, value)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
"""
Overriding this implementation to slice across the query sequence length instead of across heads.
NB: this ignores the `slice_size` factor since we interpret it differently and use a value that is
derived from the sequence length based on an empirical attention matrix memory target.
"""
# Begin IPU modifications.
attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
if num_slices < 2:
return self._attention(query, key, value, attention_mask)

num_slices = max(num_slices, 1)
num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
slice_size = query.shape[1] // num_slices

@@ -79,7 +84,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_m
start_idx = i * slice_size
end_idx = (i + 1) * slice_size

attn_slice = torch.matmul(query[:, start_idx:end_idx], key) * self.scale
attn_slice = torch.matmul(query[:, start_idx:end_idx], key) * attn.scale
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[:, start_idx:end_idx]
attn_slice = attn_slice.softmax(dim=-1)
@@ -88,9 +93,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_m
hidden_states.append(attn_slice)

hidden_states = torch.cat(hidden_states, dim=1)
# End IPU modifications.

hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


@@ -135,16 +146,15 @@ def forward(


class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
def change_cross_attention_class(self, attn_matrix_target_mem_mb=None):
def change_cross_attention_processor(self, attn_matrix_target_mem_mb):
for module in self.modules():
if isinstance(module, CrossAttention):
module.__class__ = IPUCrossAttention
module._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb
module.set_processor(IPUSlicedAttnProcessor(attn_matrix_target_mem_mb))

def parallelize(self, attn_matrix_target_mem_mb=None):
super().parallelize()

self.change_cross_attention_class(attn_matrix_target_mem_mb=attn_matrix_target_mem_mb)
self.change_cross_attention_processor(attn_matrix_target_mem_mb)

self.conv_in = poptorch.BeginBlock(self.conv_in, "conv_in", ipu_id=0)
self.down_blocks[2].downsamplers[0] = poptorch.BeginBlock(
@@ -387,6 +397,14 @@ def from_pretrained(
**kwargs,
)

def set_attention_slice(self, slice_size: Optional[int]):
# Another side effect of letting this go through is that CrossAttention could set
# a different processor than what we intended, so do the simple thing for now.
logger.warn(
"Attention slicing is enabled by default. Specifying a custom value "
"for the `slice_size` is currently unsupported."
)

def detach_from_device(self):
for module in [self.text_encoder, self.unet.unet, self.vae, self.safety_checker]:
if not isinstance(module, poptorch.PoplarExecutor):
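Replacing the `IPUCrossAttention` subclass with `IPUSlicedAttnProcessor` follows the attention-processor mechanism that diffusers 0.12 exposes: instead of overriding `CrossAttention` methods, a callable is installed on each attention module via `set_processor`, as `change_cross_attention_processor` does above. The processor slices along the query sequence length, and the number of slices is derived from how much memory the full attention score matrix would need relative to `attn_matrix_target_mem_mb`. A self-contained sketch of that arithmetic, using made-up shapes and an illustrative 100 MB target rather than values from this repository:

    import torch

    # Illustrative shapes after head_to_batch_dim: (batch*heads, seq_len, head_dim), fp16.
    query = torch.empty(16, 4096, 64, dtype=torch.float16)
    key = torch.empty(16, 4096, 64, dtype=torch.float16)
    attn_matrix_target_mem_mb = 100  # assumed target for the example

    # Bytes the full (batch*heads, q_len, k_len) score matrix would occupy:
    # 2 bytes * 16 * 4096 * 4096 = 512 MiB.
    attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]

    # Same derivation as the processor (ignoring its fall-back to unsliced attention
    # when fewer than two slices are needed): aim for roughly attn_matrix_target_mem_mb
    # per slice, then round to a divisor of the query length so all slices are equal.
    num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)  # -> 5
    num_slices = next(d for d in range(num_slices, 2 * num_slices + 1) if query.shape[1] % d == 0)  # -> 8
    slice_size = query.shape[1] // num_slices  # -> 512 query positions per slice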
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
INSTALL_REQUIRES = [
"transformers==4.25.1",
"optimum==1.6.1",
"diffusers[torch]==0.11.1",
"diffusers[torch]==0.12.1",
"datasets",
"tokenizers",
"torch @ https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp38-cp38-linux_x86_64.whl",
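For reference, the resulting pins are `transformers==4.25.1`, `optimum==1.6.1`, and `diffusers[torch]==0.12.1`, alongside the CPU torch 1.13.1 wheel listed above; only the diffusers pin changes in this PR.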