Balancing Memory and Compute: Efficient Strategies for Managing KV Cache in Large Language Models
Techniques for Reducing the Memory Footprint of KV Caches Without Sacrificing Performance
Introduction
In the previous post, we introduced KV caching as a method to optimize the inference process of large language models (LLMs), reducing the compute requirements from quadratic to linear scaling with the sequence length. Specifically, KV caching involves storing the key and value tensors of past tokens in GPU memory during the generation process, thus avoiding re-computation at each step.
KV caching represents a trade-off between memory usage and compute resources. While it reduces computational load, it increases memory consumption due to the need to store cached tensors. In this post, we'll delve into the challenges posed by the growing size of the KV cache and explore common strategies to address them.
The size of the KV cache grows linearly with the batch size and the total sequence length. The per-token memory consumption depends on the precision used for storing the tensors.
Let's derive the formula for the total size of the KV cache:
Given:
Batch size: b
Total sequence length: t
Number of decoder blocks / attention layers: n_layers
Number of attention heads per attention layer: n_heads
Dimension of each attention head: d_head
Precision (bytes per element): p_a
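The per-token memory consumption (in bytes) for the KV cache of a multi-head attention (MHA) model is:
$$\text{KV cache per token} = 2 \cdot n_{\text{layers}} \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot p_a$$
The total size of the KV cache (in bytes) is therefore:
$$\text{KV cache size} = b \cdot t \cdot 2 \cdot n_{\text{layers}} \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot p_a$$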
This formula accounts for the fact that for each token in each sequence in the batch, we need to store two tensors (key and value) for each attention head and each attention layer.
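To make the formula concrete, here is a small Python sketch that evaluates it: total bytes = b × t × 2 × n_layers × n_heads × d_head × p_a. The function name `kv_cache_bytes` is illustrative, and the example assumes a Llama-2-7B-like configuration (32 layers, 32 heads, d_head = 128) stored in 16-bit precision (p_a = 2 bytes).

```python
def kv_cache_bytes(b, t, n_layers, n_heads, d_head, p_a):
    """Total KV cache size in bytes for a multi-head attention (MHA) model.

    The factor 2 accounts for storing both a key and a value vector of size
    d_head for every head, in every layer, for every token of every sequence.
    """
    per_token = 2 * n_layers * n_heads * d_head * p_a
    return b * t * per_token


# Assumed Llama-2-7B-like configuration in 16-bit precision (p_a = 2 bytes)
per_token = kv_cache_bytes(b=1, t=1, n_layers=32, n_heads=32, d_head=128, p_a=2)
print(per_token / 2**20, "MiB per token")  # 0.5 MiB per token

total = kv_cache_bytes(b=8, t=4096, n_layers=32, n_heads=32, d_head=128, p_a=2)
print(total / 2**30, "GiB")  # 16.0 GiB
```

At about 0.5 MiB per token, a batch of 8 sequences of 4,096 tokens already needs 16 GiB for the cache alone, on top of the model weights.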
The challenge with KV caching lies in its unbounded growth with the total sequence length, which poses difficulties in managing GPU memory, especially since the total sequence length may not be known in advance.
The graph above illustrates the memory consumption of the KV cache for a fixed batch size 'b' and total sequence length 't'.
Exploring ways to reduce the memory footprint of the KV cache
Let's explore ways to reduce the memory footprint of the KV cache by examining each component of the formula:
How about reducing the batch size?
Decreasing the batch size does reduce the memory footprint of the KV cache (and can lower latency), but it is generally not desirable: a smaller batch lowers hardware utilization and therefore cost efficiency. In upcoming posts, we'll explain why increasing the batch size is often preferable.
How about minimizing reliance on the total sequence length?
To reduce the dependency on the total sequence length, one approach is to not store keys and values for all tokens in the sequence. The missing keys and values can then be recomputed at each iteration, trading extra computation for lower GPU memory consumption; this can pay off when inference is limited by memory bandwidth rather than compute.
Another option is to not store keys and values for tokens that the model pays little or no attention to. This can be by design, in models trained to attend only to part of the sequence. For example, Mistral-7B uses sliding window attention (SWA), also called local attention: each token only attends to the 4,096 most recent tokens, so at most 4,096 key-value pairs per layer need to be cached for each sequence, however long it grows.
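Because the window is fixed, the cache of a SWA model can be a fixed-size buffer that overwrites its oldest entry, similar in spirit to the rolling buffer cache described in the Mistral 7B paper. The sketch below assumes a single attention layer and uses NumPy; the class name `RollingKVCache` is hypothetical, and a tiny window of 4 stands in for Mistral-7B's 4,096.

```python
import numpy as np


class RollingKVCache:
    """Fixed-size KV cache for one attention layer under sliding window attention.

    The token at absolute position pos is written to slot pos % window, so the
    cache never holds more than `window` key/value pairs, however long the
    sequence grows.
    """

    def __init__(self, window, n_heads, d_head, dtype=np.float16):
        self.window = window
        self.keys = np.zeros((window, n_heads, d_head), dtype=dtype)
        self.values = np.zeros((window, n_heads, d_head), dtype=dtype)
        self.n_seen = 0  # number of tokens appended so far

    def append(self, k, v):
        """Store the (n_heads, d_head) key and value of the next token."""
        slot = self.n_seen % self.window  # overwrites the oldest entry once full
        self.keys[slot] = k
        self.values[slot] = v
        self.n_seen += 1

    def get(self):
        """Return the cached keys and values ordered from oldest to newest."""
        if self.n_seen <= self.window:
            return self.keys[: self.n_seen], self.values[: self.n_seen]
        start = self.n_seen % self.window  # slot holding the oldest entry
        order = np.roll(np.arange(self.window), -start)
        return self.keys[order], self.values[order]


# Toy example: window of 4 instead of 4,096, one head of size 1
cache = RollingKVCache(window=4, n_heads=1, d_head=1, dtype=np.float32)
for pos in range(10):
    cache.append(np.full((1, 1), pos), np.full((1, 1), pos))
keys, _ = cache.get()
print(keys.ravel())  # [6. 7. 8. 9.]
```

Memory stays constant after the window fills: only the slot index moves, and the oldest entry is overwritten in place.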
Alternatively, we can exploit patterns in how attention is distributed across the sequence. Attention modules often concentrate most of their attention on a few tokens, while many tokens contribute little to the output. Discarding those tokens from the cache amounts to approximating the attention matrix with a sparser one, and when done carefully, the impact on model accuracy (measured, for example, by perplexity) stays small.
Figure 1 — Example of attention (heat)map from the StreamingLLM paper: A lot of attention is consistently allocated to the first token and to the last neighboring tokens (local attention)
The following methods reduce the memory consumption of the KV cache without retraining or fine-tuning the model:
StreamingLLM Framework
Designed to let models trained with a finite-length context window handle arbitrarily long inputs, this framework builds on the observation that the initial tokens gather a large share of attention. It builds a sliding window by retaining only the first positional tokens ("sink tokens") and the last neighboring tokens (local attention) in the cache. The cache therefore has a fixed length, made of a fixed part and a sliding part.
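The StreamingLLM eviction rule fits in a few lines. This sketch assumes the cache of one layer is stored as NumPy arrays indexed by token along the first axis; the function name `streaming_llm_evict` and the default sizes (4 sink tokens, a window of 8) are illustrative choices, not values prescribed by the framework.

```python
import numpy as np


def streaming_llm_evict(keys, values, n_sink=4, window=8):
    """Trim a one-layer KV cache StreamingLLM-style.

    keys, values: arrays whose first axis indexes tokens.
    Keeps the first n_sink tokens (attention sinks) and the last `window`
    tokens (local attention); tokens in between are evicted.
    """
    seq_len = keys.shape[0]
    if seq_len <= n_sink + window:
        return keys, values  # still within the fixed cache length
    keep = np.r_[0:n_sink, seq_len - window:seq_len]
    return keys[keep], values[keep]


keys = np.arange(20, dtype=np.float32).reshape(20, 1)  # token i has key i
k, v = streaming_llm_evict(keys, keys.copy())
print(k.ravel().astype(int).tolist())
# [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19]
```

The cache never exceeds n_sink + window entries: the sink part stays fixed while the window slides with the generated tokens.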
H2O and Scissorhands Methods
These methods compress the KV cache by setting a maximum number of cached tokens (budget) and discarding tokens when the cache budget is reached. H2O discards one token at a time, while Scissorhands drops tokens based on a target compression ratio. Both methods exploit the observation that influential tokens at a given step tend to remain influential in future steps.
Cache Eviction Policy - Both H2O and Scissorhands use a cache eviction policy to decide which tokens to discard. Scissorhands keeps the most recent tokens and the tokens with the highest attention scores within a history window. H2O discards the tokens with the lowest accumulated attention scores, retaining those that consistently receive high attention across iterations.
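A simplified H2O-style policy for a single attention head is sketched below. It assumes the attention weights of each new query are available at every step, always protects a few recent tokens (H2O combines heavy-hitter tokens with recent ones), and otherwise evicts the cached token with the lowest accumulated attention score. The class and method names are hypothetical, and the toy attention pattern is made up for the demo.

```python
class H2OCache:
    """Simplified H2O-style eviction for a single attention head (sketch).

    Keeps at most `budget` tokens. At each step, the new query's attention
    weights over the cached tokens are added to a running score per token.
    When the budget is exceeded, the token with the lowest accumulated score
    is evicted, except for the `n_recent` most recent tokens, which are kept.
    """

    def __init__(self, budget, n_recent):
        assert 0 < n_recent < budget
        self.budget = budget
        self.n_recent = n_recent
        self.positions = []  # original positions of cached tokens, oldest first
        self.scores = []     # accumulated attention score of each cached token

    def step(self, pos, attn_weights):
        """Cache token `pos`; attn_weights covers all cached tokens incl. the new one."""
        self.positions.append(pos)
        self.scores.append(0.0)
        assert len(attn_weights) == len(self.positions)
        for i, w in enumerate(attn_weights):
            self.scores[i] += float(w)
        if len(self.positions) > self.budget:
            # Only tokens outside the recent window are eviction candidates
            candidates = range(len(self.positions) - self.n_recent)
            victim = min(candidates, key=lambda i: self.scores[i])
            del self.positions[victim]
            del self.scores[victim]


cache = H2OCache(budget=4, n_recent=2)
for pos in range(6):
    n = len(cache.positions) + 1  # cached tokens plus the new one
    # Toy attention: the oldest cached token gets half the mass, the rest share it
    weights = [1.0] if n == 1 else [0.5] + [0.5 / (n - 1)] * (n - 1)
    cache.step(pos, weights)
print(cache.positions)  # [0, 1, 4, 5]
```

Token 0, which keeps attracting attention, survives every eviction, while the budget is enforced by dropping one low-scoring token per step.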
FastGen Method
FastGen focuses on preserving model accuracy by setting a maximum approximation error for the attention matrix instead of a cache budget. It profiles the model's attention layers to determine compression policies during a prefill phase. These policies, such as keeping special tokens or punctuation tokens, are applied to the KV cache at each generation step to meet the error target. If the target is too stringent, FastGen falls back to regular KV caching.
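FastGen's policy selection can be illustrated with a simplified sketch for one attention head. As an assumption, it replaces FastGen's attention-matrix error criterion with a simpler proxy: the fraction of the last prompt token's attention mass that falls on the kept tokens must reach a threshold T. Only three candidate policies are tried here (FastGen considers others, such as keeping frequently attended tokens), and the function name, token indices and attention values are hypothetical.

```python
import numpy as np


def fastgen_select_policy(attn_last, special_idx, punct_idx, local_window, T=0.95):
    """Choose a KV cache compression policy for one attention head (simplified).

    attn_last: 1-D attention weights of the last prompt token over all prompt
    tokens, collected during prefill (sums to 1). Policies are tried from
    cheapest to most expensive; the first one whose kept tokens receive at
    least a fraction T of the attention mass is returned. If none qualifies,
    fall back to the full cache (regular KV caching).
    """
    seq_len = len(attn_last)
    special = set(special_idx)
    special_punct = special | set(punct_idx)
    recent = set(range(max(0, seq_len - local_window), seq_len))
    policies = [  # nested sets, so each keeps at least as many tokens as the previous
        ("special", special),
        ("special+punct", special_punct),
        ("special+punct+local", special_punct | recent),
    ]
    for name, kept in policies:
        idx = sorted(kept)
        if attn_last[idx].sum() >= T:
            return name, idx
    return "full", list(range(seq_len))


# Hypothetical prefill attention of the last prompt token over 8 prompt tokens
attn = np.array([0.6, 0.02, 0.02, 0.1, 0.02, 0.02, 0.1, 0.12])
print(fastgen_select_policy(attn, special_idx=[0], punct_idx=[3, 6], local_window=2, T=0.9))
# ('special+punct+local', [0, 3, 6, 7])
```

Raising T toward 1 makes cheaper policies fail, and past a point the head falls back to the full cache, mirroring the fallback described above.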
These methods have shown significant reductions in KV cache size with minimal loss in model accuracy. However, at the time of writing, none of them is supported by popular LLM inference frameworks.
How about reducing the number of layers?
The number of layers is fixed by the model architecture, so there is not much to gain here short of switching models. Smaller models typically have fewer layers, so if a smaller model suits your use case and performs adequately, opting for it is a straightforward solution.
How about reducing the number of attention heads?
The multi-query attention (MQA) and grouped-query attention (GQA) architectures reduce the KV cache size of Transformer models by using fewer key-value heads than query heads. They make more efficient use of memory without significantly sacrificing model performance.
In MQA, all query heads share a single key head and a single value head: every query head computes its attention scores against the same keys, and every head outputs a weighted sum of the same values, each with its own attention scores.
GQA, on the other hand, splits the query heads into groups, and the query heads within a group share one key head and one value head. This allows a more gradual reduction in the number of key-value heads than MQA, offering a compromise between model representation capacity and KV cache size.
The parameter g in GQA is the number of query-head groups, which is also the number of key-value heads, so it directly controls the KV cache size. Both MQA and MHA (multi-head attention) are special cases of GQA, with g equal to 1 and to the total number of heads, respectively.
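With GQA, only g key-value heads are cached per layer, so the KV cache size (in bytes) becomes:
$$\text{KV cache size} = b \cdot t \cdot 2 \cdot n_{\text{layers}} \cdot g \cdot d_{\text{head}} \cdot p_a$$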
These architectures have been adopted in models from various research groups, such as Google Research's PaLM, TII's Falcon models, Meta's Llama-2 (70B variant only), and Mistral AI's Mistral-7B.
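The sketch below shows how one GQA decoding step can read from a cache that holds only g key-value heads: each query head is mapped to the key-value head of its group. It uses NumPy, the function names are hypothetical, and the size comparison assumes Llama-2-70B's configuration (80 layers, 64 query heads, 8 key-value heads, d_head = 128) in 16-bit precision; the MHA and MQA rows are hypothetical variants of that model.

```python
import numpy as np


def gqa_attention(q, k_cache, v_cache):
    """One decoding step of grouped-query attention for a single layer (sketch).

    q:       (n_heads, d_head)  queries of the new token
    k_cache: (t, g, d_head)     cached keys, only g key-value heads
    v_cache: (t, g, d_head)     cached values
    Query heads are split into g groups of consecutive heads, and each group
    uses one key-value head. g == n_heads is MHA, g == 1 is MQA.
    """
    n_heads, d_head = q.shape
    g = k_cache.shape[1]
    assert n_heads % g == 0, "n_heads must be divisible by g"
    group_size = n_heads // g
    kv_head = np.arange(n_heads) // group_size  # KV head used by each query head
    # Gather per-query-head keys/values for the computation; the cache stays at g heads
    k = k_cache[:, kv_head, :]  # (t, n_heads, d_head)
    v = v_cache[:, kv_head, :]
    scores = np.einsum("hd,thd->ht", q, k) / np.sqrt(d_head)  # (n_heads, t)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # softmax over cached tokens
    return np.einsum("ht,thd->hd", weights, v)  # (n_heads, d_head)


def kv_bytes_per_token(n_layers, n_kv_heads, d_head, p_a):
    """Per-token KV cache size: MHA formula with n_heads replaced by the KV head count."""
    return 2 * n_layers * n_kv_heads * d_head * p_a


rng = np.random.default_rng(0)
q = rng.standard_normal((8, 4))           # 8 query heads
k_cache = rng.standard_normal((5, 2, 4))  # 5 cached tokens, g = 2 KV heads
v_cache = rng.standard_normal((5, 2, 4))
print(gqa_attention(q, k_cache, v_cache).shape)  # (8, 4)

# Assumed Llama-2-70B configuration, 16-bit precision
for name, g in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    print(name, kv_bytes_per_token(80, g, 128, 2) // 1024, "KiB per token")
# MHA 2560 KiB per token
# GQA 320 KiB per token
# MQA 40 KiB per token
```

For this configuration, GQA with g = 8 shrinks the per-token cache by a factor of 8 compared with MHA, while every query head still gets its own attention scores.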
How about the hidden dimension of attention heads?
Once again, there is not much to gain here unless you are willing to switch to another model.
How about using fewer bytes per parameter?
Quantizing the KV cache is an effective way to reduce its size. However, since the cached keys and values are activations, weight-only quantization does not help: we need algorithms that quantize both weights and activations. Algorithms like LLM.int8() or SmoothQuant are suitable for this purpose, as they quantize both, resulting in a reduced memory footprint.
However, since LLM inference is typically limited by memory bandwidth rather than compute, we don't actually need to compute in lower precision: quantizing the cached keys and values before writing them to GPU memory and dequantizing them right after reading them back could suffice. This approach reduces the memory footprint without the overhead of more complex quantization algorithms.
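As an illustration of this simpler approach, the sketch below applies symmetric round-to-nearest int8 quantization to cached keys, with one float32 scale per token; attention would still be computed in 16-bit precision after dequantization. This is a plain scheme chosen for clarity, not LLM.int8(), SmoothQuant, or the method of any particular framework, and the function names are hypothetical.

```python
import numpy as np


def quantize_int8(x):
    """Symmetric per-token int8 quantization of a (t, n_heads, d_head) tensor.

    Returns the int8 values and one float32 scale per token.
    """
    amax = np.abs(x).reshape(x.shape[0], -1).max(axis=1)  # largest magnitude per token
    scale = np.where(amax > 0, amax / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(x / scale[:, None, None]), -127, 127).astype(np.int8)
    return q, scale


def dequantize_int8(q, scale, dtype=np.float16):
    """Recover approximate values right before they are used in attention."""
    return (q.astype(np.float32) * scale[:, None, None]).astype(dtype)


rng = np.random.default_rng(0)
k = rng.standard_normal((16, 8, 64)).astype(np.float16)  # toy cached keys
q, s = quantize_int8(k.astype(np.float32))  # quantize before writing to the cache
k_hat = dequantize_int8(q, s)               # dequantize after reading it back
print(k.nbytes, q.nbytes + s.nbytes)  # 16384 8256
print(np.abs(k.astype(np.float32) - k_hat.astype(np.float32)).max())
```

The int8 cache takes roughly half the memory of the 16-bit one; the per-token scales add a small overhead, and the rounding error per element is at most half a scale step.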
Some inference systems, such as FlexGen, NVIDIA TensorRT-LLM, and the vLLM framework, already support KV cache quantization. They store the KV cache (and, in some cases, the model weights) in reduced-precision formats (4-bit or 8-bit), quantizing on the fly without requiring a calibration step at each iteration.