Skip to content

Commit 197c5c0

Browse files
weiwangmetaMichael Gschwind
and
Michael Gschwind
authored
Fix cuda/cpu check on NoneType (#88854) (#90068)
Summary: Fix cuda/cpu check on NoneType Test Plan: sabdcastle/ github CI/CD Differential Revision: D41203955 Pull Request resolved: #88854 Approved by: https://github.com/drisspg, https://github.com/ngimel Co-authored-by: Michael Gschwind <mikekg@meta.com>
1 parent aadbeb7 commit 197c5c0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/nn/modules/activation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
11161116
# generator expressions.
11171117
if torch.overrides.has_torch_function(tensor_args):
11181118
why_not_fast_path = "some Tensor argument has_torch_function"
1119-
elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
1119+
elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
11201120
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
11211121
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
11221122
why_not_fast_path = ("grad is enabled and at least one of query or the "

0 commit comments

Comments
 (0)