small simplification
lucidrains committed Jul 18, 2023
1 parent 55443ab commit d73dcdd
Showing 2 changed files with 5 additions and 7 deletions.
flash_attention_jax/flash_attention.py (4 additions, 6 deletions)
@@ -38,22 +38,20 @@ def chunk_scanner(carries, _):

  block_row_max = jnp.max(attn_weights, axis = -1, keepdims = True)

- exp_weights = jnp.exp(attn_weights - block_row_max)
+ new_row_max = jnp.maximum(block_row_max, row_max)
+ exp_weights = jnp.exp(attn_weights - new_row_max)

  exp_weights = jnp.where(key_mask_chunk, exp_weights, 0.)
  block_row_sum = jnp.sum(exp_weights, axis = -1, keepdims = True) + EPSILON

  exp_values = einsum('i ... j, j ... d -> i ... d', exp_weights, v_chunk)

- new_row_max = jnp.maximum(block_row_max, row_max)
-
  exp_row_max_diff = jnp.exp(row_max - new_row_max)
- exp_block_row_max_diff = jnp.exp(block_row_max - new_row_max)

- new_row_sum = exp_row_max_diff * row_sum + exp_block_row_max_diff * block_row_sum
+ new_row_sum = exp_row_max_diff * row_sum + block_row_sum

  out = (row_sum / new_row_sum) * exp_row_max_diff * out + \
-       (exp_block_row_max_diff / new_row_sum) * exp_values
+       (1. / new_row_sum) * exp_values

  return (chunk_idx + k_chunk_sizes, out, new_row_sum, new_row_max), None

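For context, this hunk is the online-softmax bookkeeping in the key-chunk scan: each chunk's attention scores are exponentiated against a running row maximum, and the previously accumulated sum and output are rescaled whenever that maximum grows. Before this commit, the current block was exponentiated against its own max (block_row_max) and then corrected by exp(block_row_max - new_row_max) when merged into the accumulators; after it, the block is exponentiated against the updated running max (new_row_max) directly, so the correction factor is exactly 1 and drops out of both new_row_sum and out. A minimal standalone sketch (illustrative only, not the repository's API; shapes and names are hypothetical) of why the two forms agree:

import jax.numpy as jnp
from jax import random

scores = random.normal(random.PRNGKey(0), (4, 8))   # scores for one key chunk
row_max = jnp.full((4, 1), -1.)                      # running max from earlier chunks

block_row_max = jnp.max(scores, axis=-1, keepdims=True)
new_row_max = jnp.maximum(block_row_max, row_max)

# before this commit: exponentiate against the block max, then apply
# the correction factor exp(block_row_max - new_row_max) when merging
before = jnp.exp(scores - block_row_max) * jnp.exp(block_row_max - new_row_max)

# after this commit: subtract the updated running max up front,
# so the correction factor is exactly 1
after = jnp.exp(scores - new_row_max)

assert jnp.allclose(before, after)   # exp(a - b) * exp(b - c) == exp(a - c)

Apart from a negligible difference in where EPSILON enters block_row_sum (before the rescaling in the old form, after it in the new one), the two formulations are algebraically identical, which is why only the two accumulator expressions needed to change.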
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
  setup(
    name = 'flash-attention-jax',
    packages = find_packages(exclude=[]),
-   version = '0.3.0',
+   version = '0.3.1',
    license='MIT',
    description = 'Flash Attention - in Jax',
    author = 'Phil Wang',
