Skip to content

Commit

Permalink
Merge branch 'main' into shortfin-refactor-2
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleHerndon authored Mar 4, 2025
2 parents 356af6b + 29a1ee7 commit 9ff898e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions sharktank/sharktank/layers/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def write_timestep(
"""
device = self.device
page_table = self.unflatten_page_table(state) # 6D
+ page_table = page_table.flatten(0, 3)
bs, *_ = seq_positions.shape
assert len(cache_partitions) == self.cache_partition_count

Expand Down Expand Up @@ -274,15 +275,18 @@ def write_timestep(

partitions = partitions.repeat(bs, 1)

- indices = (page_id, transformer_block, partitions, page_offset)
+ index = page_id
+ index = index * self.transformer_block_count + transformer_block
+ index = index * self.cache_partition_count + partitions
+ index = index * self.block_seq_stride + page_offset
values = ops.to(cache_partition, dtype=page_table.dtype)
if page_table.dtype == torch.float8_e4m3fnuz:
# Workaround for Torch not supporting torch.Tensor.index_copy_ for f8.
page_table_as_int8 = page_table.view(dtype=torch.int8)
values_int8 = values.view(dtype=torch.int8)
- page_table_as_int8.index_put_(indices=indices, values=values_int8)
+ page_table_as_int8.index_put_(indices=(index,), values=values_int8)
else:
- page_table.index_put_(indices=indices, values=values)
+ page_table.index_put_(indices=(index,), values=values)

return

Expand Down
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ By default, the port is set to 8000. If you would like to change this, use `--po
You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`.

```
- python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single"
+ python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled
```
- Wait until your server outputs:
```
Expand Down

0 comments on commit 9ff898e

Please sign in to comment.