Make kv_cache read operation perform a single gather (#1027)
Slices can merge efficiently with the attention / GQA kernels, so it is
more computationally efficient to perform a single gather until gather
fusion is better.

Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
rsuderman authored Mar 4, 2025
1 parent 9a022fb commit 8f29c7a
Showing 1 changed file with 14 additions and 31 deletions.
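
To make the intent concrete, here is a minimal, self-contained PyTorch sketch of the two read strategies; it is not the sharktank code. The shapes and the names PAGES, BLOCKS, PARTS, STRIDE, HEADS, DIM, and block_index are illustrative assumptions, and plain torch.index_select stands in for sharktank's ops.index_select.

import torch

# Hypothetical page-table layout: [page, transformer_block, k/v partition,
# block_seq_stride, head, head_dim]. All sizes here are made up.
PAGES, BLOCKS, PARTS, STRIDE, HEADS, DIM = 8, 4, 2, 16, 2, 4
page_table = torch.randn(PAGES, BLOCKS, PARTS, STRIDE, HEADS, DIM)
bs, block_seq_len = 2, 3
page_ids = torch.randint(0, PAGES, (bs, block_seq_len))
block_index = 1  # the transformer block being read

# Before this commit: one gather per cache partition, indexing a table
# flattened down to [page * block * partition, ...].
flat3 = page_table.flatten(0, 2)
base = page_ids * (BLOCKS * PARTS) + block_index * PARTS
k_old = torch.index_select(flat3, 0, (base + 0).flatten(0, 1))
v_old = torch.index_select(flat3, 0, (base + 1).flatten(0, 1))

# After this commit: a single gather over a table flattened only to
# [page * block, ...]; the partition axis survives the gather, and both
# k and v are sliced out afterwards.
flat2 = page_table.flatten(0, 1)
ids = (page_ids * BLOCKS + block_index).flatten(0, 1)
both = torch.index_select(flat2, 0, ids)
k_new, v_new = both[:, 0], both[:, 1]

assert torch.equal(k_old, k_new) and torch.equal(v_old, v_new)

The assertion is the point of the change: one gather over the coarser table, followed by slices that the attention / GQA kernels can absorb, returns exactly the same k and v as the two finer-grained gathers.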
sharktank/sharktank/layers/kv_cache.py (45 changes: 14 additions & 31 deletions)
@@ -182,45 +182,28 @@ def read(
         blocked_shape = [
             bs,
             block_seq_len,
+            self.cache_partition_count,
             self.block_seq_stride,
             self.attn_head_count // self.shard_count,
             self.attn_head_dim,
         ]
 
-        # Reshape the page cache into sub-blocks so that we can index at the
-        # granularity of the transformer_block and cache partition.
-        # This requires us to recompute indices to the sub-block reference
-        # frame.
-        # The subblock slab is organized as:
-        # [page, attn_layer, cache_partition]
-        # Where the cache line can be 0 (k) or 1 (v).
-        subblock_table = page_table.flatten(start_dim=0, end_dim=2)
-        page_stride = self.transformer_block_count * self.cache_partition_count
-        transformer_block_stride = self.cache_partition_count
-        base_subblock_ids = page_ids * page_stride + (
-            transformer_block_index * transformer_block_stride
-        )
+        # Gather both partitions and split post gather. This is more
+        # computationally efficient without gather fusion:
+        subblock_table = page_table.flatten(start_dim=0, end_dim=1)
+        page_stride = self.transformer_block_count
 
-        def read_cache_partition(index: int):
-            subblock_ids = base_subblock_ids + index
-            # TODO: Potentially clamp all page 0 indices to the mask value.
-            # Or even better, require that the ids are replicated such that access is
-            # legal.
-            # Now for each of the k/v attn_block_ids, which have been adjusted to
-            # index into the sub-pages, we flatten to do a linear index_select
-            # copy of the sub-blocks by collapsing the first two dims so we have
-            # a linear list.
-            selected = (
-                ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
-                .unflatten(0, blocked_shape[0:2])
-                .flatten(1, 2)
-            )
-            return selected
+        transformer_block_index = torch.full(
+            (bs, block_seq_len), transformer_block_index
+        )
+        subblock_ids = page_ids * page_stride + transformer_block_index
+        selected = ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
 
-        key = read_cache_partition(0)
-        value = read_cache_partition(1)
+        selected = selected.unflatten(0, blocked_shape[:2])
+        key = selected[:, :, 0, :seq_len].flatten(1, 2)[:, :seq_len]
+        value = selected[:, :, 1, :seq_len].flatten(1, 2)[:, :seq_len]
 
-        return key[:, :seq_len], value[:, :seq_len]
+        return key, value
 
     def write_timestep(
         self,
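For the reassembly at the end of the new read path, a hedged sketch with assumed sizes (none of these numbers come from the module's real configuration): the flat gather result is unflattened back to [bs, block_seq_len, ...], k and v are split along the partition axis, the block and stride axes are merged into a single token axis, and padding beyond seq_len is trimmed. The diff's inner :seq_len slice (before the flatten) is dropped here; under these assumptions it is redundant with the final trim.

import torch

bs, block_seq_len, parts, stride, heads, dim = 2, 3, 2, 16, 2, 4
seq_len = 40  # anything <= block_seq_len * stride (48 here)

# Stand-in for the gather output: one row per (batch, block) pair.
selected = torch.randn(bs * block_seq_len, parts, stride, heads, dim)

selected = selected.unflatten(0, (bs, block_seq_len))
key = selected[:, :, 0].flatten(1, 2)[:, :seq_len]  # [bs, seq_len, heads, dim]
value = selected[:, :, 1].flatten(1, 2)[:, :seq_len]
assert key.shape == (bs, seq_len, heads, dim)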
