APB: Accelerating Distributed Long-Context Inference by Passing Compressed Context Blocks across GPUs
Blog | Chinese Blog | Paper (arXiv)
10x Lossless Long-Context Inference Speedup with Sequence Parallelism-Aware Approximate Attention
APB is a distributed long-context inference framework that leverages multi-host approximate attention to enhance inference speed, achieving speedups of up to 9.2x, 4.2x, and 1.6x compared to Flash Attention, Ring Attention, and Star Attention, respectively.
APB applies a tailored approximate attention mechanism to a sequence parallelism framework. Its inference process consists of the following steps, in order.
- Context Splitting: The input sequence is split evenly across the hosts, and each host's block is prepended with an anchor block. The anchor block consists of the starting positions of the input sequence. Notably, the anchor block used in APB is smaller than the one used in Star Attention.
- Block Compression: Before the attention calculation, the KV cache of each block is compressed via Locret's retaining heads.
- Communication: Each compressed context block is sent to every host. Each host then constructs its passing block by concatenating the compressed context blocks sent by the previous hosts.
- Computation: Attention is calculated over the anchor block, the passing block, and the local context block. The passing block is discarded right after attention and does not participate in subsequent calculations.
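The four steps above can be sketched at the level of token positions. This is a minimal illustration, not the APB implementation: the function names (`split_context`, `compress_block`, `apb_attention_inputs`) are hypothetical, the per-position `scores` are a stand-in for the importance scores that Locret's retaining heads would produce, and the sketch assumes the sequence length is divisible by the number of hosts. It only shows which positions each host would attend to; no attention is computed.

```python
from typing import Dict, List


def split_context(tokens: List[int], num_hosts: int) -> List[List[int]]:
    """Evenly split the sequence into one local context block per host."""
    block_len = len(tokens) // num_hosts  # assumes an exact division
    return [tokens[i * block_len:(i + 1) * block_len] for i in range(num_hosts)]


def compress_block(block: List[int], scores: Dict[int, float], keep: int) -> List[int]:
    """Keep the `keep` highest-scoring positions, in their original order.

    Stand-in for Locret's retaining heads (assumption for illustration).
    """
    top = sorted(block, key=lambda pos: scores[pos], reverse=True)[:keep]
    return sorted(top)


def apb_attention_inputs(
    tokens: List[int],
    scores: Dict[int, float],
    num_hosts: int,
    anchor_len: int,
    keep: int,
) -> List[Dict[str, List[int]]]:
    """Return, for each host, the anchor, passing, and local blocks it attends to."""
    anchor = tokens[:anchor_len]                      # start of the input sequence
    local_blocks = split_context(tokens, num_hosts)   # context splitting
    compressed = [compress_block(b, scores, keep) for b in local_blocks]  # compression
    plan = []
    for host, local in enumerate(local_blocks):
        # Communication: concatenate compressed blocks from all previous hosts.
        passing = [pos for prev in compressed[:host] for pos in prev]
        # Computation would run attention over anchor + passing + local,
        # after which the passing block is discarded.
        plan.append({"anchor": anchor, "passing": passing, "local": local})
    return plan


# Toy example: 8 positions, 4 hosts, 1 anchor position, 1 kept position per block.
toy_scores = {pos: float(pos % 2) for pos in range(8)}
plan = apb_attention_inputs(list(range(8)), toy_scores, num_hosts=4, anchor_len=1, keep=1)
print(plan[2])  # {'anchor': [0], 'passing': [1, 3], 'local': [4, 5]}
```

In this toy run, host 2 sees the anchor position, one retained position from each of hosts 0 and 1, and its own two local positions, which reflects why the passing block stays small relative to the full preceding context.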
conda create -n apb python=3.9
conda activate apb
pip install -r requirements.txt
pip install experiments/flash-attention-apb
pip install experiments/ring-flash-attention-main
| Backbone Model | HF Repo |
|---|---|
| Llama-3.1-8B-instruct | Link |
| Qwen-2.5-14B-instruct | Link |
| Yi-34B-200K | Link |
| Llama-3-8B-1M-instruct | Link |
We provide an example of using APB to process an NIAH-Simple-1-like query with 8 GPUs.
First, modify the model path, the Locret path, and the digits (the needle in NIAH tasks) in example/llama.sh.
Then, run the following command.
bash example/llama.sh
The expected output is as follows (if the digits are set to 688435772345):
Ground Truth: 688435772345
Prediction: 688435772345.
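To check a run automatically, the two output lines can be compared with a few lines of Python. This is a hypothetical helper, not part of the repository; it assumes the output keeps the `Ground Truth:` / `Prediction:` format shown above and ignores any trailing punctuation after the predicted digits.

```python
import re


def needle_recovered(output: str) -> bool:
    """Return True if the predicted digits equal the ground-truth digits."""
    truth = re.search(r"Ground Truth:\s*(\d+)", output)
    pred = re.search(r"Prediction:\s*(\d+)", output)
    return bool(truth and pred and truth.group(1) == pred.group(1))


sample = "Ground Truth: 688435772345\nPrediction: 688435772345."
print(needle_recovered(sample))  # True
```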
Please set up the environment first, then refer to experiments/README.md for details.
Please cite our paper if you find our work valuable.
@article{huang2025apb,
title={APB: Accelerating Distributed Long-Context Inference by Passing Compressed Context Blocks across GPUs},
author={Huang, Yuxiang and Li, Mingye and Han, Xu and Xiao, Chaojun and Zhao, Weilin and Sun, Ao and Zhou, Hao and Zhou, Jie and Liu, Zhiyuan and Sun, Maosong},
journal={arXiv preprint arXiv:2502.12085},
year={2025}
}
The benchmark framework is partially adapted from Star Attention, MInference, RULER, and InfiniteBench. We deeply thank the authors of these repositories for their contributions to the long-context inference community.