MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention

Microsoft Corporation, University of Surrey
{hjiang,chengzhang,yuqyang}@microsoft.com, yucheng.li@surrey.ac.uk

Now you can process 1M-token contexts up to 10x faster on a single A100 using long-context LLMs such as LLaMA-3-8B-1M and GLM-4-1M, with even better accuracy. Try MInference 1.0 right now!

News

  1. 🥤  [24/07/24] MInference now supports meta-llama/Meta-Llama-3.1-8B-Instruct.

  2. 🪗  [24/07/07] Thanks to @AK for sponsoring. You can now use MInference online in the HF Demo with ZeroGPU.

  3. 📃  [24/07/03] Due to an issue with arXiv, the PDF is currently unavailable there. You can find the paper at this link.

  4. 🧩  [24/07/03] We will present MInference 1.0 at the Microsoft Booth and ES-FoMo at ICML'24. See you in Vienna!



Abstract

The computational challenges of LLM inference remain a significant barrier to their widespread deployment, especially as prompt lengths continue to increase. Due to the quadratic complexity of the attention computation, it takes 30 minutes for an 8B LLM to process a prompt of 1M tokens (i.e., the pre-filling stage) on a single A100 GPU. Existing methods for speeding up pre-filling often fail to maintain acceptable accuracy or efficiency when applied to long-context LLMs. To address this gap, we introduce MInference, a sparse calculation method designed to accelerate the pre-filling stage of long-sequence processing. Specifically, we identify three unique patterns in long-context attention matrices (A-shape, Vertical-Slash, and Block-Sparse) that can be leveraged for efficient sparse computation on GPUs. We determine the optimal pattern for each attention head offline and dynamically build sparse indices based on the assigned pattern during inference. With the pattern and sparse indices, we perform efficient sparse attention calculations via our optimized GPU kernels to significantly reduce the latency in the pre-filling stage of long-context LLMs. Our proposed technique can be directly applied to existing LLMs without any modifications to the pre-training setup or additional fine-tuning. By evaluating on a wide range of downstream tasks, including InfiniteBench, RULER, PG-19, and Needle In A Haystack, and models including LLaMA-3-1M, Yi-200K, GLM-4-1M, Phi-3-128K, and Qwen2-128K, we demonstrate that MInference effectively reduces inference latency by up to 10x for pre-filling on an A100, while maintaining accuracy.



Insights

  1. Attention, especially in long-context scenarios, is sparse and dynamic, i.e., the sparse patterns differ substantially across inputs.

  2. This dynamic sparsity exhibits three unique spatial aggregation patterns that persist across all inputs: A-shape, Vertical-Slash, and Block-Sparse.

  3. These dynamic sparse indices can be approximated online with minimal overhead and used to speed up attention with a custom optimized GPU kernel.
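To make the three shapes concrete, the toy sketch below builds dense boolean masks for each pattern on a small sequence. It is illustrative only: the helper names and parameters (`sink`, `local`, `vertical_idx`, `slash_offsets`, `kept_blocks`) are hypothetical and not MInference's API, and real kernels never materialize an n x n mask.

```python
import numpy as np

def _grid(n: int):
    """Query (row) and key (column) position grids for an n x n attention matrix."""
    return np.arange(n)[:, None], np.arange(n)[None, :]

def a_shape_mask(n: int, sink: int, local: int) -> np.ndarray:
    """A-shape: the first `sink` key columns plus a local causal band of width `local`."""
    q, k = _grid(n)
    return (k <= q) & ((k < sink) | (q - k < local))

def vertical_slash_mask(n: int, vertical_idx, slash_offsets) -> np.ndarray:
    """Vertical-Slash: selected key columns plus selected diagonals (q - k = offset)."""
    q, k = _grid(n)
    vertical = np.isin(k, vertical_idx)       # shape (1, n), broadcasts over rows
    slash = np.isin(q - k, slash_offsets)     # shape (n, n)
    return (k <= q) & (vertical | slash)

def block_sparse_mask(n: int, block: int, kept_blocks) -> np.ndarray:
    """Block-Sparse: whole (query_block, key_block) tiles, clipped to the causal triangle."""
    q, k = _grid(n)
    mask = np.zeros((n, n), dtype=bool)
    for qb, kb in kept_blocks:
        mask[qb * block:(qb + 1) * block, kb * block:(kb + 1) * block] = True
    return mask & (k <= q)

n = 16
masks = {
    "A-shape": a_shape_mask(n, sink=2, local=3),
    "Vertical-Slash": vertical_slash_mask(n, vertical_idx=[0, 5], slash_offsets=[0, 1]),
    "Block-Sparse": block_sparse_mask(n, block=4, kept_blocks=[(0, 0), (1, 1), (3, 0), (3, 3)]),
}
causal_total = n * (n + 1) // 2
for name, m in masks.items():
    print(f"{name:>14}: {int(m.sum())} of {causal_total} causal entries kept")
    print("\n".join("".join("#" if x else "." for x in row) for row in m))
```

Printing the masks as `#`/`.` grids shows the A-shape (sink columns plus a band), the vertical and diagonal lines, and the tiled blocks side by side.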


Why MInference?

Long-context LLM inference faces two major challenges: 1) high attention latency in the pre-filling stage, and 2) high storage and transfer costs for the KV cache. Previous efficient methods for long-context LLMs have focused on KV-cache compression, static sparse attention (e.g., model compression, SSM, linear attention), or distributed serving. However, these methods struggle to achieve acceptable latency for million-token prompts at low cost on a single A100 GPU.
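A rough FLOP count shows why pre-filling latency at this scale is dominated by attention. The sketch below assumes a LLaMA-3-8B-style configuration (32 layers, hidden size 4096, about 8B parameters) and the A100's 312 TFLOP/s dense BF16 peak; it ignores memory traffic and kernel efficiency, so real latency is higher than these peak-throughput figures.

```python
# Back-of-the-envelope pre-filling cost for a LLaMA-3-8B-style model (assumed config).
N_LAYERS = 32          # transformer layers
D_MODEL = 4096         # 32 query heads x 128 head dim (GQA shrinks K/V, not QK^T FLOPs)
N_PARAMS = 8e9         # approximate parameter count
PEAK_FLOPS = 312e12    # A100 dense BF16 peak in FLOP/s; real kernels reach less

def attention_flops(n_tokens: int) -> float:
    """Causal attention FLOPs over all layers: QK^T and PV each cost
    2 * n^2 * d_model when dense, and the causal mask roughly halves that."""
    dense_per_layer = 2 * (2 * n_tokens**2 * D_MODEL)
    return 0.5 * dense_per_layer * N_LAYERS

def linear_flops(n_tokens: int) -> float:
    """Weight matmuls (projections and MLP): about 2 FLOPs per parameter per token."""
    return 2 * N_PARAMS * n_tokens

for n in (10_000, 100_000, 1_000_000):
    attn, lin = attention_flops(n), linear_flops(n)
    print(f"{n:>9,} tokens: attention {attn / PEAK_FLOPS:7.1f} s, "
          f"linear {lin / PEAK_FLOPS:5.1f} s at peak, attention share {attn / (attn + lin):.0%}")
```

Under these assumptions, attention at 1M tokens needs about 2.6e17 FLOPs, roughly 14 minutes even at peak throughput, versus under a minute for the weight matmuls; at 10K tokens the balance is reversed. The quadratic term is what a sparse pre-filling method has to attack.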



To address these issues, we propose MInference, where the name reflects our ambition to enable million-token inference on a single A100 machine. MInference is a training-free efficient method for the pre-filling stage of long-context LLMs based on dynamic sparse attention. Specifically, we leverage the static spatial aggregation patterns of dynamic sparse attention, as shown in Fig. (3), and classify the dynamic sparse patterns into three types: A-shape, Vertical-Slash, and Block-Sparse. MInference first determines the optimal dynamic sparse pattern for each head offline using the Kernel-Aware Sparse Pattern Search algorithm, as illustrated in Alg. (1). During inference, it dynamically approximates the dynamic sparse indices based on the head's pattern, as shown in Algs. (2) and (3). Finally, we perform efficient dynamic sparse attention computation using our optimized GPU kernel, significantly reducing the pre-filling stage latency for long-context LLMs.
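Alg. (1) is not reproduced here, but the idea of an offline, per-head pattern search can be sketched: on a calibration input, compute the dense attention of one head, build a candidate mask for each pattern within a common compute budget, and keep the pattern that retains the most attention mass. This is a simplified sketch under stated assumptions: the budget counts raw mask entries rather than the kernel-level cost that makes the real search "kernel-aware", scoring by retained attention mass is an assumed criterion, the Vertical-Slash candidate picks its lines from the full attention matrix, and all function names are hypothetical.

```python
import numpy as np

def causal_attention_probs(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Dense causal softmax(Q K^T / sqrt(d)) for one head (reference only, O(n^2) memory)."""
    n, d = q.shape
    scores = (q @ k.T) / np.sqrt(d)
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)   # the diagonal keeps every row finite
    probs = np.exp(scores)
    return probs / probs.sum(axis=1, keepdims=True)

def attention_recall(probs: np.ndarray, mask: np.ndarray) -> float:
    """Average share of each query's attention mass that the sparse mask keeps."""
    return float((probs * mask).sum() / probs.shape[0])

def search_head_pattern(probs, candidates, budget):
    """Hypothetical search: among candidate masks computing at most `budget`
    entries, return (pattern name, recall) with the highest recall."""
    best_name, best_recall = None, -1.0
    for name, mask in candidates.items():
        if mask.sum() > budget:
            continue                               # too expensive for this budget
        r = attention_recall(probs, mask)
        if r > best_recall:
            best_name, best_recall = name, r
    return best_name, best_recall

# Toy calibration head: later queries attend strongly to three "hot" keys.
rng = np.random.default_rng(0)
n, d = 32, 16
u = np.ones(d)
q = u + 0.1 * rng.normal(size=(n, d))
k = 0.1 * rng.normal(size=(n, d))
k[[0, 9, 21]] += 4.0 * u
probs = causal_attention_probs(q, k)

rows, cols = np.indices((n, n))
causal = cols <= rows
top_cols = np.argsort(probs.sum(axis=0))[-3:]      # columns holding the most mass
candidates = {
    "A-shape": causal & ((cols < 1) | (rows - cols < 4)),
    "Vertical-Slash": causal & (np.isin(cols, top_cols) | (rows == cols)),
    "Block-Sparse": causal & ((rows // 8 == cols // 8) | (cols // 8 == 0)),
}
best, best_r = search_head_pattern(probs, candidates, budget=400)
for name, mask in candidates.items():
    print(f"{name:>14}: {int(mask.sum())} entries, recall {attention_recall(probs, mask):.3f}")
print("selected:", best)
```

Because the search runs once per head offline, its cost does not affect inference latency; only the chosen pattern (and its configuration) is carried into the online stage.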


For example, with the Vertical-Slash pattern, we first compute the attention between the last queries (Q) and all keys (K) to estimate the optimal indices of the vertical and slash lines. We then use the dynamic sparse compiler PIT and Triton to build the Vertical-Slash FlashAttention kernel, which accelerates the attention computation. For the Block-Sparse pattern, we first apply mean pooling to Q and K before the attention calculation. By leveraging the commutative property of mean pooling and MatMul, we estimate the block-sparse indices. We then use Triton to build the Block-Sparse FlashAttention kernel, which accelerates the attention computation. For the detailed kernel implementation, please refer to Appendix C.4 and the code.
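The two online estimators described above can be sketched in NumPy. This is an illustrative sketch, not the PIT/Triton implementation: the function names, the default `last_q`, the line counts, the block size, and the top-k selection rule are assumptions chosen for a toy example. The `block_mean` helper checks the identity behind the mean-pooling trick: averaging a tile of QK^T equals multiplying the mean-pooled Q and K blocks, because the product is linear in each operand. The identity is exact for pre-softmax scores; the softmax that follows makes the block estimate an approximation.

```python
import numpy as np

def estimate_vertical_slash(q, k, last_q=4, n_vertical=2, n_slash=2):
    """Estimate vertical-line columns and slash-line offsets for one head
    from the attention of only the last `last_q` queries."""
    n, d = q.shape
    rows = np.arange(n - last_q, n)[:, None]           # positions of the last queries
    cols = np.arange(n)[None, :]
    scores = (q[-last_q:] @ k.T) / np.sqrt(d)           # (last_q, n) instead of (n, n)
    scores = np.where(cols <= rows, scores, -np.inf)    # causal
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)

    vertical_score = probs.sum(axis=0)                  # mass per key column
    offsets = rows - cols                               # diagonal offset q - k of each entry
    valid = offsets >= 0
    slash_score = np.zeros(n)
    np.add.at(slash_score, offsets[valid], probs[valid])  # mass per diagonal
    vertical_idx = np.sort(np.argsort(vertical_score)[-n_vertical:])
    slash_offsets = np.sort(np.argsort(slash_score)[-n_slash:])
    return vertical_idx, slash_offsets

def estimate_block_sparse(q, k, block=4, top_k=3):
    """Score (query_block, key_block) tiles with mean-pooled Q and K and keep the
    `top_k` highest-scoring causal key blocks per query block. Assumes n % block == 0."""
    n, d = q.shape
    nb = n // block
    q_pool = q.reshape(nb, block, d).mean(axis=1)
    k_pool = k.reshape(nb, block, d).mean(axis=1)
    block_scores = (q_pool @ k_pool.T) / np.sqrt(d)     # (nb, nb) instead of (n, n)
    kept = []
    for i in range(nb):
        causal_blocks = np.arange(i + 1)
        ranked = causal_blocks[np.argsort(block_scores[i, causal_blocks])[::-1]]
        kept.extend((i, int(j)) for j in ranked[:top_k])
    return kept

def block_mean(m, block):
    """Average an (n, n) matrix over non-overlapping block x block tiles."""
    nb = m.shape[0] // block
    return m.reshape(nb, block, nb, block).mean(axis=(1, 3))

rng = np.random.default_rng(0)
n, d, block = 16, 8, 4
q, k = rng.normal(size=(n, d)), rng.normal(size=(n, d))
pooled_scores = q.reshape(-1, block, d).mean(axis=1) @ k.reshape(-1, block, d).mean(axis=1).T
print("pooling commutes with QK^T:", np.allclose(pooled_scores, block_mean(q @ k.T, block)))
print("vertical/slash:", estimate_vertical_slash(q, k))
print("block-sparse tiles:", estimate_block_sparse(q, k, block=block))
```

Both estimators touch far fewer scores than dense attention (`last_q x n` and `(n/block)^2`), which is what keeps the online indexing overhead small relative to the sparse kernel itself.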

Our main contributions are four-fold:

  1. We propose a dynamic sparse attention method, MInference, to accelerate the pre-filling stage of long-context LLMs by up to 10x for 1M token prompts while maintaining the capabilities of LLMs, especially their retrieval abilities, as demonstrated in tasks like Needle in a Haystack.

  2. We classify dynamic sparse attention in LLMs into three patterns and design the Kernel-Aware Sparse Pattern Search algorithm to find the optimal head pattern offline.

  3. We introduce an online approximation method and optimized GPU kernels that accelerate LLM inference with minimal overhead. We also provide an optimized inference codebase that enables 1M-token pre-filling on a single A100 using LLaMA-style models.

  4. We evaluate MInference on four benchmarks (InfiniteBench, RULER, PG-19, and Needle in a Haystack) with token lengths ranging from 128K to 1M to assess the actual context processing capabilities of LLMs. Experimental results show that MInference maintains or slightly improves actual context processing capabilities, while also outperforming baselines in cost efficiency and system latency.



Experiments Results in Long-context Benchmarks

We tested MInference across a range of scenarios, including QA, coding, retrieval-based tasks, multi-hop QA, summarization, and math tasks. The RULER benchmark includes several complex multi-hop or multi-needle tasks, effectively reflecting the actual context window size of LLMs. As shown in Tab.(1), our method effectively preserves the actual context window processing capability of LLMs and even slightly extends the actual context window size to 32K.

Model / Method | Claimed context | Effective context | 4K | 8K | 16K | 32K | 64K | 128K | Avg.
LLaMA-3-262K | 262K | 16K | 97.2 | 91.8 | 87.3 | 80.8 | 77.4 | 77.2 | 84.4
StreamingLLM | - | 4K | 97.2 | 38.1 | 37.5 | 17.2 | 14.2 | 9.4 | 35.0
StreamingLLM w/ dilated | - | <4K | 23.4 | 0.7 | 1.4 | 18.8 | 16.5 | 15.6 | 12.7
StreamingLLM w/ strided | - | <4K | 2.0 | 0.7 | 0.6 | 0.6 | 0.7 | 1.3 | 1.0
InfLLM | - | 4K | 89.4 | 79.8 | 70.1 | 55.6 | 43.0 | 39.5 | 62.9
MInference | - | 32K | 97.7 | 91.2 | 88.5 | 85.0 | 82.3 | 77.6 | 87.0

Table 1. Performance (%) of different models and different methods on RULER evaluated at lengths from 4k to 128k.


We also tested MInference on a broader range of tasks using InfiniteBench, whose average token length is 214K, as shown in Tab.(2). Compared to the SoTA baselines, MInference consistently maintains performance across all tasks. Notably, on more challenging retrieval tasks such as KV retrieval, all baselines fail to make accurate predictions, with accuracy of at most 1.2%. In contrast, MInference retains the ability to handle dynamic KV-pair retrieval.

Model / Method | En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Debug | Math.Find | PassKey | Number | KV | Avg.
LLaMA-3-262K | 20.2 | 12.4 | 67.4 | 6.0 | 12.9 | 22.1 | 26.6 | 100.0 | 100.0 | 14.4 | 38.2
StreamingLLM | 21.0 | 8.2 | 40.2 | 10.0 | 10.4 | 25.9 | 30.0 | 86.8 | 5.1 | 0.8 | 23.8
StreamingLLM w/ dilated | 20.1 | 9.4 | 44.5 | 15.5 | 11.2 | 20.5 | 27.5 | 5.0 | 87.5 | 0.5 | 24.2
StreamingLLM w/ strided | 17.3 | 8.2 | 27.5 | 14.5 | 11.2 | 19.5 | 27.5 | 4.0 | 2.1 | 1.0 | 13.3
InfLLM | 24.1 | 7.8 | 45.0 | 6.0 | 11.4 | 19.5 | 32.9 | 100.0 | 100.0 | 1.2 | 34.8
MInference w/ static | 19.9 | 8.6 | 43.2 | 3.5 | 8.9 | 20.6 | 25.1 | 92.4 | 96.3 | 0.2 | 31.9
MInference | 20.5 | 12.9 | 65.9 | 7.5 | 12.5 | 22.3 | 33.1 | 100.0 | 100.0 | 12.8 | 38.8

Table 2. Performance of different methods with different base models on InfiniteBench.


To further evaluate performance across different context lengths and positions of key information within prompts, we tested various models and methods on the Needle in a Haystack task. As shown in Fig.(1), MInference performs well across different models, context windows, and positions within the prompt, maintaining or even slightly improving performance compared to the original models. For LLaMA-3-8B and GLM-4-9B-1M, MInference achieves all-green results for context windows up to 1M. In comparison, the performance of StreamingLLM and InfLLM drops below 20% in the middle segments of prompts, even with 70K context windows.


Figure 1. Needle In A Haystack results using LLaMA-3-8B-Instruct-1M, GLM-4-9B-1M, Yi-9B-200K, Phi-3-Mini-128K, and Qwen2-7B-128K.


We also tested MInference on language modeling using PG-19, with sequences of up to 100K tokens. As shown in Fig.(2), MInference effectively maintains the perplexity of LLaMA-3-8B and Yi-9B-200K, while all baselines show varying degrees of perplexity degradation. Additionally, StreamingLLM with dilated and strided configurations maintains perplexity better than standard StreamingLLM.


Figure 2. Perplexity results on PG-19 using different models and methods.



Latency Breakdown and Sparsity Pattern in the Kernel

Fig.(3) shows the micro-benchmark results for the three attention patterns proposed in this paper, as well as for FlashAttention. Vertical-Slash is the slowest of the three patterns, but it still achieves a 13x speedup over FlashAttention at a 1M-token context window.


Figure 3. The latency breakdown of a single attention kernel for three patterns and FlashAttention across different context windows in a single A100, including the index time for dynamic sparse approximation and building dynamic sparsity. At 10k tokens, the latency of the four kernels is very close and all are less than 1ms. At 1M tokens, the latency for A-shape is 164ms.


Fig.(4) shows the sparse indices in the kernel of the Vertical-Slash head. The vertical lines are computed using 1x64 blocks through PIT FlashAttention, while the slash lines are computed using 64x64 blocks through Block-level FlashAttention.


Figure 4. Schematic of the dynamic sparse mask in the kernel for the Vertical-Slash pattern, using LLaMA-3-8B on the summarization task. Yellow areas indicate the parts actually involved in computation. The slash lines are covered with 64 x 64 blocks, while the vertical lines are covered with 1 x 64 blocks.
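The cost of this mixed granularity can be estimated with a toy model. The sketch below assumes a simplified reading of the caption: each slash line is rounded up to whole 64 x 64 tiles in every 64-row query block it crosses, while each vertical line stays one column wide within each query block. It compares the entries lying exactly on the lines with the entries such a kernel would compute. The line positions are hypothetical, and this is not the PIT/Triton kernel.

```python
import numpy as np

BLOCK = 64  # tile size quoted in the figure caption

def exact_vertical_slash(n, vertical_idx, slash_offsets):
    """Causal entries lying exactly on the chosen vertical lines or slash diagonals."""
    rows, cols = np.indices((n, n))
    lines = np.isin(cols, vertical_idx) | np.isin(rows - cols, slash_offsets)
    return lines & (cols <= rows)

def kernel_coverage(n, vertical_idx, slash_offsets, block=BLOCK):
    """Causal entries a tile-based kernel would compute under the simplified model:
    slash lines rounded up to block x block tiles, vertical lines kept one column wide."""
    covered = np.zeros((n, n), dtype=bool)
    for r0 in range(0, n, block):
        r1 = min(r0 + block, n)
        for o in slash_offsets:
            c_lo, c_hi = r0 - o, r1 - 1 - o          # columns this diagonal crosses here
            if c_hi < 0:
                continue                             # diagonal starts further down
            for c0 in range(max(c_lo, 0) // block * block, c_hi + 1, block):
                covered[r0:r1, c0:c0 + block] = True
        covered[r0:r1, list(vertical_idx)] = True
    rows, cols = np.indices((n, n))
    return covered & (cols <= rows)

n = 1024
vertical_idx = [0, 1, 2, 3, 500]          # hypothetical vertical lines
slash_offsets = [0, 1, 2, 3, 64, 130]     # hypothetical slash offsets
exact = exact_vertical_slash(n, vertical_idx, slash_offsets)
covered = kernel_coverage(n, vertical_idx, slash_offsets)
causal_total = n * (n + 1) // 2
print(f"exact lines : {exact.sum() / causal_total:.2%} of causal entries")
print(f"with tiling : {covered.sum() / causal_total:.2%} of causal entries")
```

Rounding slash lines to tiles computes extra entries but keeps memory access dense, while vertical lines, which are scattered across the key axis, would waste most of a 64 x 64 tile if rounded the same way.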



FAQ

Q1: How to effectively evaluate the impact of dynamic sparse attention on the capabilities of long-context LLMs?
To evaluate long-context LLM capabilities using models like LLaMA-3-8B-Instruct-1M and GLM-4-9B-1M, we tested: 1) context window with RULER, 2) general tasks with InfiniteBench, 3) retrieval tasks with Needle in a Haystack, and 4) language model prediction with PG-19.
We found that traditional methods perform poorly on retrieval tasks, with difficulty ranked as follows: KV retrieval > Needle in a Haystack > Retrieval.Number > Retrieval.PassKey. The main challenge is the semantic difference between the needles and the haystack: traditional methods do well when this difference is large, as in passkey tasks. KV retrieval requires stronger retrieval capabilities because any key can be the target, and multi-needle tasks are even more complex.
We will continue to update our results with more models and datasets in future versions.
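The difference in needle/haystack contrast can be illustrated with two toy prompt generators. Both formats are hypothetical approximations of the style of these tasks, not the exact InfiniteBench templates: in the passkey prompt the needle is the only sentence that differs from repetitive filler, whereas in the KV-retrieval prompt every key and value is a random UUID, so the target pair is indistinguishable from the rest of the context until the queried key is known.

```python
import json
import random
import uuid

def make_passkey_example(n_filler: int, seed: int = 0):
    """Passkey-style prompt: one distinctive sentence hidden in repetitive filler."""
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))
    parts = ["The grass is green. The sky is blue. The sun is yellow. "] * n_filler
    parts.insert(rng.randint(0, n_filler), f"The pass key is {passkey}. Remember it. ")
    return "".join(parts) + "What is the pass key?", passkey

def make_kv_retrieval_example(n_pairs: int, seed: int = 0):
    """KV-retrieval-style prompt: every key and value is a random UUID,
    so the target pair looks exactly like every other pair."""
    rng = random.Random(seed)

    def new_uuid() -> str:
        return str(uuid.UUID(int=rng.getrandbits(128)))

    pairs = {new_uuid(): new_uuid() for _ in range(n_pairs)}
    target = rng.choice(list(pairs))
    prompt = (
        "Extract the value corresponding to the specified key in the JSON object below.\n"
        f"{json.dumps(pairs)}\n"
        f'Key: "{target}"\nThe value associated with the specified key is: '
    )
    return prompt, pairs[target]

passkey_prompt, passkey = make_passkey_example(n_filler=200)
kv_prompt, value = make_kv_retrieval_example(n_pairs=50)
print(len(passkey_prompt), passkey)
print(len(kv_prompt), value)
```

A sparse method that keeps only "salient" tokens can find the passkey sentence by its contrast with the filler, but for KV retrieval the relevant positions are determined by the final query, which is why query-dependent (dynamic) indices matter there.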


Q2: Does this dynamic sparse attention pattern only exist in long-context LLMs that are not fully trained?
Firstly, attention is dynamically sparse; this characteristic is inherent to the attention mechanism itself. We selected state-of-the-art long-context LLMs, GLM-4-9B-1M and LLaMA-3-8B-Instruct-1M, whose effective context windows are 64K and 16K, respectively. With MInference, these become 64K and 32K: the GLM-4-9B-1M window is preserved and the LLaMA-3-8B-Instruct-1M window is doubled. We will continue to adapt our method to other advanced long-context LLMs and update our results, as well as explore the theoretical basis for this dynamic sparse attention pattern.


Q3: Does this dynamic sparse attention pattern only exist in Auto-regressive LMs or RoPE based LLMs?
Similar vertical and slash line sparse patterns have been observed in BERT[1] and multi-modal LLMs[2]. Our analysis of T5's attention patterns, shown in Figure 5, reveals that these patterns persist across different heads, even with bidirectional attention.
[1] SparseBERT: Rethinking the Importance Analysis in Self-Attention, ICML 2021.
[2] LOOK-M: Look-Once Optimization in KV Cache for Efficient Multimodal Long-Context Inference, 2024.

Figure 5. The sparse pattern in T5 Encoder.


Q4: What is the relationship between MInference, SSM, Linear Attention, and Sparse Attention?
All four approaches (MInference, SSM, Linear Attention, and Sparse Attention) efficiently optimize attention complexity in Transformers, each introducing inductive bias differently. The latter three require training from scratch. Recent works like Mamba-2 and Unified Implicit Attention Representation unify SSM and Linear Attention as static sparse attention, with Mamba-2 itself being a block-wise sparse method. While these approaches show potential due to sparse redundancy in attention, static sparse attention may struggle with dynamic semantic associations in complex tasks. In contrast, dynamic sparse attention is better suited for managing these relationships.



BibTeX

If you find this project helpful, please cite the following paper:

@article{jiang2024minference,
    title={MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention},
    author={Jiang, Huiqiang and Li, Yucheng and Zhang, Chengruidong and Wu, Qianhui and Luo, Xufang and Ahn, Surin and Han, Zhenhua and Abdi, Amir H and Li, Dongsheng and Lin, Chin-Yew and Yang, Yuqing and Qiu, Lili},
    journal={arXiv preprint arXiv:2407.02490},
    year={2024}
}