
Commit 1cad980

remove torch.nn.attention dependency
1 parent b1adf50 commit 1cad980

File tree

2 files changed: +46 −46 lines changed

ldm/modules/attention.py (+42 −41)
@@ -26,40 +26,40 @@
 # import apex
 # from apex.normalization import FusedRMSNorm as RMSNorm

-if version.parse(torch.__version__) >= version.parse("2.0.0"):
-    SDP_IS_AVAILABLE = True
-    # from torch.backends.cuda import SDPBackend, sdp_kernel
-    from torch.nn.attention import sdpa_kernel, SDPBackend
-
-    BACKEND_MAP = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
-    }
-else:
-    from contextlib import nullcontext
-
-    SDP_IS_AVAILABLE = False
-    sdpa_kernel = nullcontext
-    BACKEND_MAP = {}
-    logpy.warn(
-        f"No SDP backend available, likely because you are running in pytorch "
-        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
-        f"You might want to consider upgrading."
-    )
+# if version.parse(torch.__version__) >= version.parse("2.0.0"):
+#     SDP_IS_AVAILABLE = True
+#     # from torch.backends.cuda import SDPBackend, sdp_kernel
+#     from torch.nn.attention import sdpa_kernel, SDPBackend
+
+#     BACKEND_MAP = {
+#         SDPBackend.MATH: {
+#             "enable_math": True,
+#             "enable_flash": False,
+#             "enable_mem_efficient": False,
+#         },
+#         SDPBackend.FLASH_ATTENTION: {
+#             "enable_math": False,
+#             "enable_flash": True,
+#             "enable_mem_efficient": False,
+#         },
+#         SDPBackend.EFFICIENT_ATTENTION: {
+#             "enable_math": False,
+#             "enable_flash": False,
+#             "enable_mem_efficient": True,
+#         },
+#         None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+#     }
+# else:
+#     from contextlib import nullcontext
+
+#     SDP_IS_AVAILABLE = False
+#     sdpa_kernel = nullcontext
+#     BACKEND_MAP = {}
+#     logpy.warn(
+#         f"No SDP backend available, likely because you are running in pytorch "
+#         f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+#         f"You might want to consider upgrading."
+#     )


 def exists(val):
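The block commented out above hard-imports torch.nn.attention at module load, but that module only ships with recent PyTorch releases, so the import fails even though the surrounding version check passes for any torch >= 2.0 — presumably why the commit disables the whole branch. A minimal sketch of a softer fix that probes the import instead (hypothetical, not what this commit does):

import torch.nn.functional as F
from contextlib import nullcontext

try:
    # torch.nn.attention is only present in recent PyTorch releases
    from torch.nn.attention import sdpa_kernel, SDPBackend

    def sdpa_ctx():
        # pin torch's flash-attention kernel, as the removed default did
        return sdpa_kernel([SDPBackend.FLASH_ATTENTION])
except ImportError:
    # older torch: no backend pinning, let SDPA dispatch freely
    sdpa_ctx = nullcontext

Callers would then wrap F.scaled_dot_product_attention in sdpa_ctx() and get backend pinning only when the API is actually available.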
@@ -211,7 +211,7 @@ def __init__(
         dim_head=64,
         dropout=0.0,
         # backend=None,
-        backend=SDPBackend.FLASH_ATTENTION, # FA implemented by torch.
+        # backend=SDPBackend.FLASH_ATTENTION, # FA implemented by torch.
         **kwargs,
     ):
         super().__init__()
@@ -228,7 +228,7 @@ def __init__(
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
         )
-        self.backend = backend
+        # self.backend = backend

     def forward(
         self,
@@ -282,11 +282,12 @@ def forward(
         """
         ## new
         # with sdpa_kernel(**BACKEND_MAP[self.backend]):
-        with sdpa_kernel([self.backend]): # new signature
-            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-            out = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask
-            ) # scale is dim_head ** -0.5 per default
+        # with sdpa_kernel([self.backend]): # new signature
+
+        # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+        out = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask
+        ) # scale is dim_head ** -0.5 per default

         del q, k, v
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
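With the sdpa_kernel context gone, F.scaled_dot_product_attention simply lets PyTorch dispatch to whichever backend (flash, memory-efficient, or math) it deems best for the inputs, and the softmax scale still defaults to dim_head ** -0.5, as the retained comment notes. A quick standalone check of the call as it now stands, with made-up shapes in the (batch, heads, tokens, dim_head) layout the surrounding code uses:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # b h n d
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# no backend context: PyTorch picks the kernel; scale defaults to 64 ** -0.5
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
assert out.shape == (2, 8, 128, 64)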

vit/vision_transformer.py (+4 −5)
@@ -66,7 +66,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 assert version.parse(torch.__version__) >= version.parse("2.0.0")
 SDP_IS_AVAILABLE = True
 # from torch.backends.cuda import SDPBackend, sdp_kernel
-from torch.nn.attention import sdpa_kernel, SDPBackend
+# from torch.nn.attention import sdpa_kernel, SDPBackend


 class Attention(nn.Module):
@@ -110,7 +110,7 @@ def __init__(self,
         self.no_flash_op = no_flash_op
         self.attn_mode = "torch"

-        self.backend = SDPBackend.FLASH_ATTENTION # FA implemented by torch.
+        # self.backend = SDPBackend.FLASH_ATTENTION # FA implemented by torch.

     @staticmethod
     def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
@@ -198,9 +198,8 @@ def forward(self, x):
         q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
         q, k = self.q_norm(q), self.k_norm(k)

-        with sdpa_kernel([self.backend]): # new signature
-
-            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        # with sdpa_kernel([self.backend]): # new signature
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)

         del q, k, v
         x = rearrange(x, "B H L D -> B L (H D)")
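Both files keep the older "# from torch.backends.cuda import SDPBackend, sdp_kernel" hint in comments. If pinning flash attention is ever wanted again without depending on torch.nn.attention, that legacy context manager (present since PyTorch 2.0, deprecated in later releases) takes exactly the enable_* flags the removed BACKEND_MAP stored — a hedged sketch, not part of this commit; flash attention needs a CUDA device and half-precision inputs:

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# same flags as BACKEND_MAP[SDPBackend.FLASH_ATTENTION] in the removed code
with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
    x = F.scaled_dot_product_attention(q, k, v)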
