# import apex
# from apex.normalization import FusedRMSNorm as RMSNorm

- if version.parse(torch.__version__) >= version.parse("2.0.0"):
-     SDP_IS_AVAILABLE = True
-     # from torch.backends.cuda import SDPBackend, sdp_kernel
-     from torch.nn.attention import sdpa_kernel, SDPBackend
-
-     BACKEND_MAP = {
-         SDPBackend.MATH: {
-             "enable_math": True,
-             "enable_flash": False,
-             "enable_mem_efficient": False,
-         },
-         SDPBackend.FLASH_ATTENTION: {
-             "enable_math": False,
-             "enable_flash": True,
-             "enable_mem_efficient": False,
-         },
-         SDPBackend.EFFICIENT_ATTENTION: {
-             "enable_math": False,
-             "enable_flash": False,
-             "enable_mem_efficient": True,
-         },
-         None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
-     }
- else:
-     from contextlib import nullcontext
-
-     SDP_IS_AVAILABLE = False
-     sdpa_kernel = nullcontext
-     BACKEND_MAP = {}
-     logpy.warn(
-         f"No SDP backend available, likely because you are running in pytorch "
-         f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
-         f"You might want to consider upgrading."
-     )
+ # if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ #     SDP_IS_AVAILABLE = True
+ #     # from torch.backends.cuda import SDPBackend, sdp_kernel
+ #     from torch.nn.attention import sdpa_kernel, SDPBackend
+
+ #     BACKEND_MAP = {
+ #         SDPBackend.MATH: {
+ #             "enable_math": True,
+ #             "enable_flash": False,
+ #             "enable_mem_efficient": False,
+ #         },
+ #         SDPBackend.FLASH_ATTENTION: {
+ #             "enable_math": False,
+ #             "enable_flash": True,
+ #             "enable_mem_efficient": False,
+ #         },
+ #         SDPBackend.EFFICIENT_ATTENTION: {
+ #             "enable_math": False,
+ #             "enable_flash": False,
+ #             "enable_mem_efficient": True,
+ #         },
+ #         None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ #     }
+ # else:
+ #     from contextlib import nullcontext
+
+ #     SDP_IS_AVAILABLE = False
+ #     sdpa_kernel = nullcontext
+ #     BACKEND_MAP = {}
+ #     logpy.warn(
+ #         f"No SDP backend available, likely because you are running in pytorch "
+ #         f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+ #         f"You might want to consider upgrading."
+ #     )


def exists(val):
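
For context, the BACKEND_MAP removed above only translated each SDPBackend value into the enable_* keyword flags expected by the old torch.backends.cuda.sdp_kernel context manager. The replacement torch.nn.attention.sdpa_kernel accepts a backend, or a list of allowed backends, directly, which is why the map and its per-backend flags become dead code here. A minimal sketch of the newer API, assuming PyTorch >= 2.3 and purely illustrative tensor shapes:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors with shape (batch, heads, seq_len, dim_head).
q = k = v = torch.randn(1, 8, 128, 64)

# The context manager takes the allowed backends directly; no enable_* flags are needed.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)  # default scale is dim_head ** -0.5

print(out.shape)  # torch.Size([1, 8, 128, 64])
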
@@ -211,7 +211,7 @@ def __init__(
        dim_head=64,
        dropout=0.0,
        # backend=None,
-         backend=SDPBackend.FLASH_ATTENTION,  # FA implemented by torch.
+         # backend=SDPBackend.FLASH_ATTENTION,  # FA implemented by torch.
        **kwargs,
    ):
        super().__init__()
@@ -228,7 +228,7 @@ def __init__(
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
-         self.backend = backend
+         # self.backend = backend

    def forward(
        self,
@@ -282,11 +282,12 @@ def forward(
        """
        ## new
        # with sdpa_kernel(**BACKEND_MAP[self.backend]):
-         with sdpa_kernel([self.backend]):  # new signature
-             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-             out = F.scaled_dot_product_attention(
-                 q, k, v, attn_mask=mask
-             )  # scale is dim_head ** -0.5 per default
+         # with sdpa_kernel([self.backend]):  # new signature
+
+         # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+         out = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=mask
+         )  # scale is dim_head ** -0.5 per default

        del q, k, v
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
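
The final hunk drops the backend-selection context manager entirely and calls F.scaled_dot_product_attention directly, letting PyTorch dispatch to whatever fused kernel is available. As a sanity check of what that call computes, the sketch below compares the fused kernel against the explicit softmax(q k^T / sqrt(d)) v formulation; the shapes are illustrative and not taken from the module:

import math
import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 16, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# Fused kernel; with no attn_mask and no explicit scale it defaults to 1 / sqrt(d).
fused = F.scaled_dot_product_attention(q, k, v)

# Explicit reference: softmax(q @ k^T / sqrt(d)) @ v.
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
reference = scores.softmax(dim=-1) @ v

print(torch.allclose(fused, reference, atol=1e-5))  # True within numerical tolerance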