
Commit 84a8d61

Leymore and Ali Hassani authored Mar 12, 2025
Add additional KV support for flex_na (#211)
As described in the title. For an unknown reason, flex_na3d with additional KV needs `eps=0.12` rather than `eps=0.1` to pass the precision check. Will come back to this later.

Co-authored-by: Ali Hassani <ahassani@nvidia.com>
1 parent 469df7c commit 84a8d61

9 files changed: +365, -110 lines changed
 

CHANGELOG.md (+2)
@@ -10,6 +10,8 @@
 * Now you can use Flex Attention instead of FNA through NATTEN directly.
   * Just import `use_flex_attention()` from `natten`, call it, and enjoy potentially significant
     speedups on newer architectures.
+    * With support for additional KV tokens.
+* Better precision on fused ops with additional KV.
 
 
 ## [0.17.4] - 2025-01-28
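For orientation, a minimal usage sketch of the opt-in described in the changelog entry above. Only `use_flex_attention()` (exported from `natten` alongside `force_flex_attention`, per the `__init__.py` diff below) comes from this commit; the rest is assumed setup for illustration.

```python
# Minimal sketch (assumed setup): opt into the Flex Attention backend globally,
# so na1d/na2d/na3d are routed through flex_na* instead of FNA.
from natten import use_flex_attention

use_flex_attention()  # subsequent NATTEN ops go through torch's Flex Attention
```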

requirements-dev.txt (+1, -1)
@@ -7,6 +7,6 @@ mypy==1.8.0
 pytest==7.4.4
 click==8.1.7
 rich
-xformers>=0.0.25
+xformers==v0.0.28.post3
 fvcore==0.1.5.post20221221
 twine

src/natten/__init__.py (+4, -1)
@@ -30,6 +30,7 @@
     enable_gemm_na,
     enable_tf32,
     enable_tiled_na,
+    force_flex_attention,
     get_memory_usage_preference,
     has_bfloat,
     has_cuda,
@@ -57,6 +58,7 @@
     set_memory_usage_preference,
     use_autotuner,
     use_deterministic_algorithms,
+    use_flex_attention,
     use_fna,
     use_fused_na,
     use_gemm_na,
@@ -81,6 +83,7 @@
     "use_fused_na",
     "is_fused_na_enabled",
     "use_autotuner",
+    "force_flex_attention",
     "disable_autotuner",
     "is_autotuner_enabled",
     "is_autotuner_enabled_for_forward",
@@ -99,6 +102,7 @@
     "use_tf32_in_gemm_na",
     "use_tiled_na",
     "use_gemm_na",
+    "use_flex_attention",
     "is_tf32_in_gemm_na_enabled",
     "is_tiled_na_enabled",
     "is_gemm_na_enabled",
@@ -114,7 +118,6 @@
     "disable_gemm_na",
     "enable_tiled_na",
     "disable_tiled_na",
-    "use_flex_attention",
 ]
 
 __version__ = "0.17.5.dev0"

src/natten/flex.py (+80, -7)
@@ -23,12 +23,13 @@
 
 import functools
 import math
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 from torch import BoolTensor, IntTensor, Tensor
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
+from .ops import additional_sdpa, merge_attentions
 from .types import (
     CausalArg1DTypeOrDed,
     CausalArg2DTypeOrDed,
@@ -129,7 +130,10 @@ def flex_na1d(
     kernel_size: Dimension1DTypeOrDed,
     dilation: Dimension1DTypeOrDed = 1,
     is_causal: Optional[CausalArg1DTypeOrDed] = False,
-) -> torch.Tensor:
+    additional_keys: Optional[Tensor] = None,
+    additional_values: Optional[Tensor] = None,
+    xformers_kwargs: Optional[Dict] = None,
+) -> Tensor:
 
     kernel_size_, dilation_, is_causal_ = check_all_args(
         1, kernel_size, dilation, is_causal
@@ -170,9 +174,30 @@ def flex_na1d(
 
     na_mask = get_na_flex_mask(1, num_tokens_tuple, kernel_size_, dilation_, is_causal_)
     flex_attention_compiled = get_flex_attention_compiled()
-    out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)
+    out_, lse_ = flex_attention_compiled(
+        query_, key_, value_, block_mask=na_mask, return_lse=True
+    )
 
     out = out_.transpose(1, 2)
+    lse = lse_.transpose(1, 2)
+
+    if additional_keys is not None and additional_values is not None:
+        if additional_keys is None or additional_values is None:
+            raise ValueError(
+                "Both `additional_keys` and `additional_values` must be "
+                "either Tensors or NoneTypes."
+            )
+
+        scale = query.shape[-1] ** -0.5
+        additional_output, additional_lse = additional_sdpa(
+            query,
+            additional_keys,
+            additional_values,
+            scale=scale,
+            attn_kwargs=xformers_kwargs,
+        )
+
+        return merge_attentions(out, additional_output, lse, additional_lse)
 
     return out
 
@@ -184,7 +209,10 @@ def flex_na2d(
     kernel_size: Dimension2DTypeOrDed,
     dilation: Dimension2DTypeOrDed = 1,
     is_causal: Optional[CausalArg2DTypeOrDed] = False,
-) -> torch.Tensor:
+    additional_keys: Optional[Tensor] = None,
+    additional_values: Optional[Tensor] = None,
+    xformers_kwargs: Optional[Dict] = None,
+) -> Tensor:
 
     kernel_size_, dilation_, is_causal_ = check_all_args(
         2, kernel_size, dilation, is_causal
@@ -225,9 +253,30 @@ def flex_na2d(
 
     na_mask = get_na_flex_mask(2, num_tokens_tuple, kernel_size_, dilation_, is_causal_)
     flex_attention_compiled = get_flex_attention_compiled()
-    out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)
+    out_, lse_ = flex_attention_compiled(
+        query_, key_, value_, block_mask=na_mask, return_lse=True
+    )
 
     out = out_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads, head_dim)
+    lse = lse_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads)
+
+    if additional_keys is not None and additional_values is not None:
+        if additional_keys is None or additional_values is None:
+            raise ValueError(
+                "Both `additional_keys` and `additional_values` must be "
+                "either Tensors or NoneTypes."
+            )
+
+        scale = query.shape[-1] ** -0.5
+        additional_output, additional_lse = additional_sdpa(
+            query,
+            additional_keys,
+            additional_values,
+            scale=scale,
+            attn_kwargs=xformers_kwargs,
+        )
+
+        return merge_attentions(out, additional_output, lse, additional_lse)
 
     return out
 
@@ -239,7 +288,10 @@ def flex_na3d(
     kernel_size: Dimension3DTypeOrDed,
     dilation: Dimension3DTypeOrDed = 1,
     is_causal: Optional[CausalArg3DTypeOrDed] = False,
-) -> torch.Tensor:
+    additional_keys: Optional[Tensor] = None,
+    additional_values: Optional[Tensor] = None,
+    xformers_kwargs: Optional[Dict] = None,
+) -> Tensor:
 
     kernel_size_, dilation_, is_causal_ = check_all_args(
         3, kernel_size, dilation, is_causal
@@ -280,8 +332,29 @@ def flex_na3d(
 
     na_mask = get_na_flex_mask(3, num_tokens_tuple, kernel_size_, dilation_, is_causal_)
     flex_attention_compiled = get_flex_attention_compiled()
-    out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)
+    out_, lse_ = flex_attention_compiled(
+        query_, key_, value_, block_mask=na_mask, return_lse=True
+    )
 
     out = out_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads, head_dim)
+    lse = lse_.transpose(1, 2).view(batch_size, *num_tokens_tuple, num_heads)
+
+    if additional_keys is not None and additional_values is not None:
+        if additional_keys is None or additional_values is None:
+            raise ValueError(
+                "Both `additional_keys` and `additional_values` must be "
+                "either Tensors or NoneTypes."
+            )
+
+        scale = query.shape[-1] ** -0.5
+        additional_output, additional_lse = additional_sdpa(
+            query,
+            additional_keys,
+            additional_values,
+            scale=scale,
+            attn_kwargs=xformers_kwargs,
+        )
+
+        return merge_attentions(out, additional_output, lse, additional_lse)
 
     return out
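The change above relies on `flex_attention` returning the per-query logsumexp when called with `return_lse=True`; that LSE is what later allows merging with the additional-KV branch. A standalone sketch of that pattern follows; the shapes are illustrative assumptions, and depending on your PyTorch build (2.5+ is needed for Flex Attention) you may need a CUDA device and/or `torch.compile`.

```python
# Standalone sketch: flex_attention can return the logsumexp of the attention
# logits alongside the output. Shapes here are illustrative, not NATTEN's.
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, L, D = 1, 4, 128, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

out, lse = flex_attention(q, k, v, return_lse=True)
print(out.shape)  # (B, H, L, D) -- attention output
print(lse.shape)  # (B, H, L)    -- log-sum-exp per query, consumed by merge_attentions
```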

src/natten/functional.py (+9, -12)
@@ -1735,10 +1735,6 @@ def na1d(
             raise NotImplementedError(
                 "RPB is not supported in the Flex Attention backend."
             )
-        if additional_keys is not None or additional_values is not None:
-            raise NotImplementedError(
-                "Additional keys/values is not supported in the Flex Attention backend."
-            )
 
         return flex_na1d(
             query,
@@ -1747,6 +1743,9 @@ def na1d(
             kernel_size,
             dilation,
             is_causal,
+            additional_keys=additional_keys,
+            additional_values=additional_values,
+            xformers_kwargs=xformers_kwargs,
         )
 
     tiling_config_forward, tiling_config_backward = autotune_fna(
@@ -1817,10 +1816,6 @@ def na2d(
             raise NotImplementedError(
                 "RPB is not supported in the Flex Attention backend."
             )
-        if additional_keys is not None or additional_values is not None:
-            raise NotImplementedError(
-                "Additional keys/values is not supported in the Flex Attention backend."
-            )
 
         return flex_na2d(
             query,
@@ -1829,6 +1824,9 @@ def na2d(
             kernel_size,
             dilation,
             is_causal,
+            additional_keys=additional_keys,
+            additional_values=additional_values,
+            xformers_kwargs=xformers_kwargs,
         )
 
     tiling_config_forward, tiling_config_backward = autotune_fna(
@@ -1899,10 +1897,6 @@ def na3d(
             raise NotImplementedError(
                 "RPB is not supported in the Flex Attention backend."
             )
-        if additional_keys is not None or additional_values is not None:
-            raise NotImplementedError(
-                "Additional keys/values is not supported in the Flex Attention backend."
-            )
 
         return flex_na3d(
             query,
@@ -1911,6 +1905,9 @@ def na3d(
             kernel_size,
             dilation,
             is_causal,
+            additional_keys=additional_keys,
+            additional_values=additional_values,
+            xformers_kwargs=xformers_kwargs,
         )
 
     tiling_config_forward, tiling_config_backward = autotune_fna(
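With the guards above removed, additional KV tokens now flow through the public ops when the Flex Attention backend is active. A hedged end-to-end sketch follows; the heads-last layout, token counts, kernel size, and device/dtype are assumptions for illustration, not taken from this diff.

```python
# Hedged sketch: additional (global) KV tokens through na1d with the Flex backend.
# Layout assumed to be (batch, seqlen, heads, head_dim); adjust to your install.
import torch
from natten import use_flex_attention
from natten.functional import na1d

use_flex_attention()

B, L, H, D = 2, 256, 4, 64
num_extra = 16  # e.g. register/global tokens every query may attend to

q = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
extra_k = torch.randn(B, num_extra, H, D, device="cuda", dtype=torch.float16)
extra_v = torch.randn(B, num_extra, H, D, device="cuda", dtype=torch.float16)

out = na1d(
    q, k, v,
    kernel_size=7,
    additional_keys=extra_k,    # now forwarded to flex_na1d per the diff above
    additional_values=extra_v,
)
assert out.shape == (B, L, H, D)
```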

src/natten/ops.py (+6, -7)
@@ -164,17 +164,16 @@ def merge_attentions(
     output_0 = output_fna.reshape(input_shape).to(accum_type)
     output_1 = output_sdpa.reshape(input_shape).to(accum_type)
 
-    sum_of_exps_0 = lse_0.exp().unsqueeze(-1).expand(*input_shape)
-    sum_of_exps_1 = lse_1.exp().unsqueeze(-1).expand(*input_shape)
+    lse_max = torch.maximum(lse_0, lse_1)
+    exp_diff_0 = torch.exp(lse_0 - lse_max).unsqueeze(-1)
+    exp_diff_1 = torch.exp(lse_1 - lse_max).unsqueeze(-1)
 
-    assert sum_of_exps_0.shape == sum_of_exps_1.shape == output_0.shape
-
-    output_0_rescaled = output_0 * sum_of_exps_0
-    output_1_rescaled = output_1 * sum_of_exps_1
+    output_0_rescaled = output_0 * exp_diff_0
+    output_1_rescaled = output_1 * exp_diff_1
 
     assert output_0_rescaled.shape == output_1_rescaled.shape == output_0.shape
 
-    sum_of_exps = sum_of_exps_0 + sum_of_exps_1
+    sum_of_exps = exp_diff_0 + exp_diff_1
 
     output = (output_0_rescaled + output_1_rescaled) / sum_of_exps
 
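The rewrite above replaces direct exponentiation of the two logsumexps (which can overflow for large logits) with weights of the form exp(lse - max(lse_0, lse_1)); the common factor exp(lse_max) cancels between numerator and denominator, so the result is unchanged but stays in range. A self-contained sketch of the same identity, checked against attention over the concatenated KV set; names and shapes are illustrative and this is not NATTEN code.

```python
# Independent sketch of the numerically stable merge: two attention outputs with
# their logsumexps are combined by rescaling with exp(lse - lse_max) rather than
# exp(lse), avoiding overflow for large logits.
import torch

def merge_two_softmax_outputs(out_0, lse_0, out_1, lse_1):
    # out_*: (..., head_dim); lse_*: (...) log-sum-exp of each branch's logits
    lse_max = torch.maximum(lse_0, lse_1)
    w_0 = torch.exp(lse_0 - lse_max).unsqueeze(-1)
    w_1 = torch.exp(lse_1 - lse_max).unsqueeze(-1)
    return (out_0 * w_0 + out_1 * w_1) / (w_0 + w_1)

# Sanity check against attention computed over the concatenated KV set.
torch.manual_seed(0)
L0, L1, D = 8, 3, 4
q = torch.randn(D)
k = torch.randn(L0 + L1, D)
v = torch.randn(L0 + L1, D)

def attn(q, k, v):
    logits = k @ q
    lse = torch.logsumexp(logits, dim=-1)
    out = torch.softmax(logits, dim=-1) @ v
    return out, lse

out_0, lse_0 = attn(q, k[:L0], v[:L0])
out_1, lse_1 = attn(q, k[L0:], v[L0:])
merged = merge_two_softmax_outputs(out_0, lse_0, out_1, lse_1)
reference, _ = attn(q, k, v)
assert torch.allclose(merged, reference, atol=1e-6)
```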