Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support bfloat16 for Upsample2D #9480

Merged
merged 7 commits into from
Oct 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/diffusers/models/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

from ..utils import deprecate
from .normalization import RMSNorm


is_torch_less_than_2_1 = version.parse(version.parse(torch.__version__).base_version) < version.parse("2.1")


class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.

Expand Down Expand Up @@ -151,11 +155,10 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
if self.use_conv_transpose:
return self.conv(hidden_states)

# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
# https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
if dtype == torch.bfloat16 and is_torch_less_than_2_1:
hidden_states = hidden_states.to(torch.float32)

# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
Expand All @@ -170,8 +173,8 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
# Cast back to original dtype
if dtype == torch.bfloat16 and is_torch_less_than_2_1:
hidden_states = hidden_states.to(dtype)

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
Expand Down