import functools
import math
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple

import torch
from torch import BoolTensor, IntTensor, Tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

+from .ops import additional_sdpa, merge_attentions
from .types import (
    CausalArg1DTypeOrDed,
    CausalArg2DTypeOrDed,
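The new imports pull in `additional_sdpa` and `merge_attentions` from the local `.ops` module. Merging two attention outputs only requires each partition's output and its log-sum-exp (LSE): the combined result is a weighted sum, with weights given by renormalizing the two LSEs against their combined normalizer. A minimal pure-PyTorch sketch of that math (an illustration of the technique, not the `.ops` implementation; names and layouts here are assumptions):

```python
import torch
from torch import Tensor


def merge_attentions_reference(out_a: Tensor, out_b: Tensor, lse_a: Tensor, lse_b: Tensor) -> Tensor:
    # out_*: [batch, ...tokens, heads, head_dim] partial attention outputs
    # lse_*: [batch, ...tokens, heads] log-sum-exp of each partition's attention scores
    lse = torch.logaddexp(lse_a, lse_b)          # normalizer over the union of both key sets
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of partition A per query/head
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of partition B per query/head
    return out_a * w_a + out_b * w_b
```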
@@ -129,7 +130,10 @@ def flex_na1d(
    kernel_size: Dimension1DTypeOrDed,
    dilation: Dimension1DTypeOrDed = 1,
    is_causal: Optional[CausalArg1DTypeOrDed] = False,
-) -> torch.Tensor:
+    additional_keys: Optional[Tensor] = None,
+    additional_values: Optional[Tensor] = None,
+    xformers_kwargs: Optional[Dict] = None,
+) -> Tensor:

    kernel_size_, dilation_, is_causal_ = check_all_args(
        1, kernel_size, dilation, is_causal
@@ -170,9 +174,30 @@ def flex_na1d(

    na_mask = get_na_flex_mask(1, num_tokens_tuple, kernel_size_, dilation_, is_causal_)
    flex_attention_compiled = get_flex_attention_compiled()
-    out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)
+    out_, lse_ = flex_attention_compiled(
+        query_, key_, value_, block_mask=na_mask, return_lse=True
+    )

    out = out_.transpose(1, 2)
+    lse = lse_.transpose(1, 2)
+
+    if additional_keys is not None or additional_values is not None:
+        if additional_keys is None or additional_values is None:
+            raise ValueError(
+                "Both `additional_keys` and `additional_values` must be "
+                "either Tensors or NoneTypes."
+            )
+
+        scale = query.shape[-1] ** -0.5
+        additional_output, additional_lse = additional_sdpa(
+            query,
+            additional_keys,
+            additional_values,
+            scale=scale,
+            attn_kwargs=xformers_kwargs,
+        )
+
+        return merge_attentions(out, additional_output, lse, additional_lse)

    return out
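The branch above forwards the original `[batch, tokens, heads, head_dim]` query together with the extra key/value tokens to `additional_sdpa`, which must return both the attention output and its LSE so the two paths can be merged. A self-contained sketch of what such a helper could look like (a plain-PyTorch stand-in under assumed layouts, not the actual `.ops.additional_sdpa`, which presumably dispatches to xformers through `attn_kwargs`):

```python
import torch
from torch import Tensor


def additional_sdpa_reference(
    query: Tensor, additional_keys: Tensor, additional_values: Tensor, scale: float
) -> tuple[Tensor, Tensor]:
    # query:                  [batch, q_tokens, heads, head_dim]
    # additional_keys/values: [batch, kv_tokens, heads, head_dim]
    # returns: output [batch, q_tokens, heads, head_dim], LSE [batch, q_tokens, heads]
    q = query.transpose(1, 2)                       # [B, H, Q, D]
    k = additional_keys.transpose(1, 2)             # [B, H, K, D]
    v = additional_values.transpose(1, 2)           # [B, H, K, D]
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    lse = torch.logsumexp(scores, dim=-1)           # [B, H, Q]
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out.transpose(1, 2), lse.transpose(1, 2)
```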
@@ -184,7 +209,10 @@ def flex_na2d(
    kernel_size: Dimension2DTypeOrDed,
    dilation: Dimension2DTypeOrDed = 1,
    is_causal: Optional[CausalArg2DTypeOrDed] = False,
-) -> torch.Tensor:
+    additional_keys: Optional[Tensor] = None,
+    additional_values: Optional[Tensor] = None,
+    xformers_kwargs: Optional[Dict] = None,
+) -> Tensor:

    kernel_size_, dilation_, is_causal_ = check_all_args(
        2, kernel_size, dilation, is_causal
@@ -225,9 +253,30 @@ def flex_na2d(

    na_mask = get_na_flex_mask(2, num_tokens_tuple, kernel_size_, dilation_, is_causal_)
    flex_attention_compiled = get_flex_attention_compiled()
-    out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)
+    out_, lse_ = flex_attention_compiled(
+        query_, key_, value_, block_mask=na_mask, return_lse=True
+    )

    out = out_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads, head_dim)
+    lse = lse_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads)
+
+    if additional_keys is not None or additional_values is not None:
+        if additional_keys is None or additional_values is None:
+            raise ValueError(
+                "Both `additional_keys` and `additional_values` must be "
+                "either Tensors or NoneTypes."
+            )
+
+        scale = query.shape[-1] ** -0.5
+        additional_output, additional_lse = additional_sdpa(
+            query,
+            additional_keys,
+            additional_values,
+            scale=scale,
+            attn_kwargs=xformers_kwargs,
+        )
+
+        return merge_attentions(out, additional_output, lse, additional_lse)

    return out
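In the 2D variant, the extra tokens are typically a small set of global entries (for example register or text-conditioning tokens) that every query should see in addition to its local neighborhood. A hypothetical call under the new signature (shapes and the `extra_k`/`extra_v` names are illustrative only, not taken from the source):

```python
import torch

batch, height, width, heads, head_dim, num_extra = 2, 32, 32, 8, 64, 16

# Neighborhood attention operates on [batch, height, width, heads, head_dim] inputs.
query = torch.randn(batch, height, width, heads, head_dim, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Extra tokens attended to by every query, on top of its 7x7 neighborhood.
extra_k = torch.randn(batch, num_extra, heads, head_dim, device="cuda", dtype=torch.float16)
extra_v = torch.randn_like(extra_k)

out = flex_na2d(
    query, key, value,
    kernel_size=7,
    additional_keys=extra_k,
    additional_values=extra_v,
)
```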
@@ -239,7 +288,10 @@ def flex_na3d(
    kernel_size: Dimension3DTypeOrDed,
    dilation: Dimension3DTypeOrDed = 1,
    is_causal: Optional[CausalArg3DTypeOrDed] = False,
-) -> torch.Tensor:
+    additional_keys: Optional[Tensor] = None,
+    additional_values: Optional[Tensor] = None,
+    xformers_kwargs: Optional[Dict] = None,
+) -> Tensor:

    kernel_size_, dilation_, is_causal_ = check_all_args(
        3, kernel_size, dilation, is_causal
@@ -280,8 +332,29 @@ def flex_na3d(

    na_mask = get_na_flex_mask(3, num_tokens_tuple, kernel_size_, dilation_, is_causal_)
    flex_attention_compiled = get_flex_attention_compiled()
-    out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)
+    out_, lse_ = flex_attention_compiled(
+        query_, key_, value_, block_mask=na_mask, return_lse=True
+    )

    out = out_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads, head_dim)
+    lse = lse_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads)
+
+    if additional_keys is not None or additional_values is not None:
+        if additional_keys is None or additional_values is None:
+            raise ValueError(
+                "Both `additional_keys` and `additional_values` must be "
+                "either Tensors or NoneTypes."
+            )
+
+        scale = query.shape[-1] ** -0.5
+        additional_output, additional_lse = additional_sdpa(
+            query,
+            additional_keys,
+            additional_values,
+            scale=scale,
+            attn_kwargs=xformers_kwargs,
+        )
+
+        return merge_attentions(out, additional_output, lse, additional_lse)

    return out
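A quick way to sanity-check the merging math used in all three variants: attending to the concatenation of two key/value sets must equal merging the two partial attentions by their LSEs. A standalone check of that identity with plain softmax attention, independent of NATTEN's kernels:

```python
import torch

torch.manual_seed(0)
B, H, Q, Ka, Kb, D = 2, 4, 16, 16, 8, 32
scale = D ** -0.5

q = torch.randn(B, H, Q, D, dtype=torch.float64)
ka, va = torch.randn(B, H, Ka, D, dtype=torch.float64), torch.randn(B, H, Ka, D, dtype=torch.float64)
kb, vb = torch.randn(B, H, Kb, D, dtype=torch.float64), torch.randn(B, H, Kb, D, dtype=torch.float64)


def attend(q, k, v):
    # Returns the attention output and the log-sum-exp of the attention scores.
    s = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(torch.softmax(s, dim=-1), v), torch.logsumexp(s, dim=-1)


out_a, lse_a = attend(q, ka, va)
out_b, lse_b = attend(q, kb, vb)
lse = torch.logaddexp(lse_a, lse_b)
merged = out_a * torch.exp(lse_a - lse).unsqueeze(-1) + out_b * torch.exp(lse_b - lse).unsqueeze(-1)

# Reference: one attention pass over the union of both key/value sets.
reference, _ = attend(q, torch.cat([ka, kb], dim=2), torch.cat([va, vb], dim=2))
assert torch.allclose(merged, reference, atol=1e-10)
```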