From cf53584ed87268447804cf40199c097845cc3d1a Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 26 Oct 2022 23:07:06 +0000 Subject: [PATCH 1/2] fix `upsample_nearest_nhwc` for large bsz --- src/diffusers/models/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index fbd78b512a6b..69d39a523c03 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -48,6 +48,10 @@ def forward(self, hidden_states, output_size=None): if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: From f9fa9860b3af25455613e20296a0650867fd790a Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 27 Oct 2022 14:10:23 +0000 Subject: [PATCH 2/2] fix `upsample_nearest_nhwc` for large bsz --- src/diffusers/models/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 69d39a523c03..7bb5416adf24 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -380,6 +380,10 @@ def forward(self, input_tensor, temb): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: