Documentation Index
Fetch the complete documentation index at: https://mintlify.com/sgl-project/sglang/llms.txt
Use this file to discover all available pages before exploring further.
Memory Management
This guide covers SGLang’s memory management system, including KV cache allocation, radix caching, and memory optimizations.
Overview
SGLang’s memory system manages:
- Model weights (static, loaded once)
- KV cache (dynamic, per request)
- Activation memory (temporary, per batch)
- Workspace buffers (scratch space for kernels)
Location: python/sglang/srt/mem_cache/
Memory Layout
GPU Memory Breakdown
Total GPU Memory (e.g., 80GB on A100)
├── Model Weights (static) ~ 15GB
│ └── Fixed after loading
├── KV Cache Pool (dynamic) ~ 60GB
│ ├── Request 1: 2GB
│ ├── Request 2: 1.5GB
│ ├── Request 3: 3GB
│ └── Free space: 53.5GB
├── Activation Memory (temp) ~ 4GB
│ └── Reused across batches
└── Workspace Buffers ~ 1GB
└── Scratch space for kernels
Memory Allocation at Startup
def init_memory_pool(self, config):
    """Initialize the token-level KV cache pool.

    The KV budget is ``total_memory * config.mem_fraction_static`` minus the
    memory used by the model weights, i.e. ``mem_fraction_static`` covers
    model weights and KV cache together.

    Args:
        config: Object exposing ``mem_fraction_static`` (fraction of total GPU
            memory reserved for weights + KV cache).

    Raises:
        ValueError: If the static fraction leaves no room for even one token
            of KV cache after the model weights.
    """
    # Use the device this process is currently bound to instead of device 0.
    device = torch.cuda.current_device()
    total_mem = torch.cuda.get_device_properties(device).total_memory
    model_mem = self.get_model_memory()

    # Whole bytes: the fraction multiply yields a float, and the pool derives
    # an integer slot count from this value.
    kv_cache_mem = int(total_mem * config.mem_fraction_static - model_mem)
    token_size = self.get_kv_cache_token_size()
    max_tokens = kv_cache_mem // token_size
    if max_tokens <= 0:
        raise ValueError(
            f"mem_fraction_static={config.mem_fraction_static} leaves no memory "
            f"for the KV cache (model uses {model_mem / 1e9:.2f} GB)"
        )

    self.token_to_kv_pool = TokenToKVPool(
        size=kv_cache_mem,
        token_size=token_size,
    )

    logger.info("KV cache pool size: %.2f GB", kv_cache_mem / 1e9)
    logger.info("Max tokens: %d", max_tokens)
Token-to-KV Pool
Architecture
The TokenToKVPool manages KV cache allocation at token granularity.
class TokenToKVPool:
    """Manage KV cache memory at token granularity.

    Each slot holds the K and V vectors of one token for every layer. Slots
    are handed out from a free set by ``allocate`` and returned by ``free``.
    """

    def __init__(
        self,
        size,
        token_size,
        *,
        num_layers=32,
        num_kv_heads=8,
        head_dim=128,
        dtype=torch.float16,
        device="cuda",
    ):
        """Pre-allocate the KV buffer.

        Args:
            size: Byte budget for the pool; floored to a whole number of slots.
            token_size: Bytes of KV cache per token.
            num_layers: Number of layers stored per slot.
            num_kv_heads: Number of KV heads (fewer than query heads with GQA).
            head_dim: Size of each head vector.
            dtype: Element dtype of the buffer.
            device: Device the buffer is allocated on.

        The geometry defaults match the 32-layer / 8-KV-head / 128-dim example
        model whose per-token size is 128 KiB in FP16.
        """
        self.size = size
        self.token_size = token_size  # Bytes per token's KV cache
        # Must be an int: torch rejects float tensor dimensions, and `size`
        # may arrive as a float computed from a memory fraction.
        self.max_tokens = int(size // token_size)

        # Pre-allocate the whole pool once: [slot, layer, K/V, kv_head, head_dim]
        self.kv_data = torch.empty(
            (self.max_tokens, num_layers, 2, num_kv_heads, head_dim),
            dtype=dtype,
            device=device,
        )

        # Free list of token slots
        self.free_slots = set(range(self.max_tokens))

    def allocate(self, num_tokens):
        """Reserve ``num_tokens`` slots.

        Returns:
            A list of slot indices, or ``None`` when fewer than ``num_tokens``
            slots are free (nothing is reserved in that case).

        Raises:
            ValueError: If ``num_tokens`` is negative.
        """
        if num_tokens < 0:
            raise ValueError(f"num_tokens must be >= 0, got {num_tokens}")
        if len(self.free_slots) < num_tokens:
            return None  # OOM
        return [self.free_slots.pop() for _ in range(num_tokens)]

    def free(self, slots):
        """Return ``slots`` to the free set."""
        self.free_slots.update(slots)
Per-Token KV Cache Size
For a model with:
num_layers = 32    # transformer layers
num_kv_heads = 8   # KV heads (GQA)
head_dim = 128     # elements per head vector
dtype_bytes = 2    # fp16

# Bytes of KV cache needed for one token across all layers.
kv_cache_per_token = (
    num_layers       # 32
    * 2              # K and V
    * num_kv_heads   # 8
    * head_dim       # 128
    * dtype_bytes    # 2 bytes (fp16)
)
= 32 * 2 * 8 * 128 * 2 = 131,072 bytes = 128 KiB
For 10,000 tokens: 10,000 * 131,072 bytes ≈ 1.31 GB (≈ 1.22 GiB)
Radix Cache
Radix Tree for Prefix Sharing
Radix cache uses a tree structure to share KV cache across requests with common prefixes.
class RadixCache:
    """Radix tree (one token per edge) mapping token prefixes to KV slots.

    Each node stores the KV slot of exactly one token, so the slots for a
    prefix are obtained by concatenating node slots along its path.
    """

    class Node:
        def __init__(self):
            self.children = {}  # token_id -> Node
            self.kv_indices = []  # Slot(s) in token_to_kv_pool for this node's token
            self.ref_count = 0  # How many requests reference this
            self.last_access = time.time()

    def __init__(self):
        self.root = self.Node()
        self.total_nodes = 0

    def match(self, tokens):
        """Return the KV slots of the longest cached prefix of ``tokens``.

        Every node on the matched path has its ``last_access`` refreshed.
        Returns an empty list when no prefix is cached.
        """
        node = self.root
        matched_indices = []
        now = time.time()
        for token in tokens:
            child = node.children.get(token)
            if child is None:
                break
            node = child
            matched_indices.extend(node.kv_indices)
            node.last_access = now
        return matched_indices

    def insert(self, tokens, kv_indices):
        """Insert ``tokens`` with per-token slots ``kv_indices`` into the tree.

        ``kv_indices[i]`` must be the slot holding ``tokens[i]``. Every node on
        the path gains one reference. Nodes that already exist keep their
        original slot, so re-inserting a cached prefix never orphans slots the
        tree references; the caller keeps ownership of its own duplicates.
        """
        node = self.root
        now = time.time()
        for i, token in enumerate(tokens):
            child = node.children.get(token)
            if child is None:
                child = self.Node()
                # One token's slot per node; match() concatenates them along
                # the path (storing kv_indices[:i+1] here would duplicate them).
                child.kv_indices = kv_indices[i:i + 1]
                node.children[token] = child
                self.total_nodes += 1
            node = child
            node.ref_count += 1
            node.last_access = now

    def decrement_refs(self, tokens):
        """Release one reference on every cached node along ``tokens``' path."""
        node = self.root
        for token in tokens:
            node = node.children.get(token)
            if node is None:
                break
            if node.ref_count > 0:
                node.ref_count -= 1
Example: Prefix Sharing
# Request 1: "Translate to French: Hello"
req1_tokens = [1054, 284, 2823, 25, 15496]
req1_kv = pool.allocate(len(req1_tokens))  # e.g. [0, 1, 2, 3, 4]
radix_cache.insert(req1_tokens, req1_kv)

# Request 2: "Translate to French: Goodbye" -- shares the first 4 tokens
req2_tokens = [1054, 284, 2823, 25, 7197, 29474]
shared_kv = radix_cache.match(req2_tokens)  # [0, 1, 2, 3] (shared prefix)
num_uncached = len(req2_tokens) - len(shared_kv)  # 2 tokens still need slots
req2_kv = shared_kv + pool.allocate(num_uncached)  # e.g. [0, 1, 2, 3, 5, 6]
Memory saved: 4 shared tokens * 128 KiB = 512 KiB for each request that reuses this prefix
Memory Allocation Strategies
Lazy Allocation
Allocate KV cache incrementally as tokens are generated:
class Request:
    def allocate_kv_cache(self, num_new_tokens):
        """Reserve KV slots for ``num_new_tokens`` more tokens of this request.

        On the first call (no slots held yet) the longest cached prefix of
        ``self.input_ids`` is reused from the radix cache, and only the
        remainder is allocated from the pool. Later calls (e.g. one per decoded
        token) only append fresh slots: re-matching the prefix there would
        overwrite ``self.kv_indices`` and leak the slots allocated so far.

        Raises:
            OutOfMemoryError: If the pool cannot supply the needed slots.
        """
        if not getattr(self, "kv_indices", None):
            # First allocation: try to match a cached prefix.
            matched = self.radix_cache.match(self.input_ids)
            # Copy so later extends never mutate a list the cache handed out,
            # and so kv_indices exists even when nothing matched.
            self.kv_indices = list(matched)
            num_new_tokens -= len(matched)

        # Allocate remaining
        if num_new_tokens > 0:
            new_indices = self.pool.allocate(num_new_tokens)
            if new_indices is None:
                raise OutOfMemoryError("KV cache pool exhausted")
            self.kv_indices.extend(new_indices)
Eager Eviction
Free cache immediately when request finishes:
def finish_request(self, req):
    """Release a finished request's resources.

    Slots that belong to the request's prompt prefix as cached in the radix
    tree stay allocated: the tree still references them, and cache eviction
    frees them later. Only the request's private slots (uncached prompt
    tokens, generated tokens, duplicates of an already-cached prefix) are
    returned to the pool. Freeing the shared slots here would let the pool
    hand them out while the tree still points at them, and eviction would
    free them a second time.
    """
    # Decrement reference counts in radix tree (eviction considers only
    # nodes that no request references).
    self.radix_cache.decrement_refs(req.input_ids)

    # Slots the tree holds for this prompt's cached prefix.
    cached_slots = set(self.radix_cache.match(req.input_ids))
    self.pool.free([slot for slot in req.kv_indices if slot not in cached_slots])

    # Remove from running batch
    self.running_batch.remove(req)
Cache Eviction Policy
When memory is full, evict the least recently used (LRU) cached prefixes that no running request references:
def evict_cache(self, required_slots):
    """Evict cached prefixes, least recently used first, to free KV slots.

    Candidates are gathered by ``self._collect_candidates`` (intended to be
    nodes with ``ref_count == 0``). Eviction stops once at least
    ``required_slots`` slots have been freed.

    Returns:
        True if at least ``required_slots`` slots were freed, else False.
    """
    evictable = []
    self._collect_candidates(self.radix_cache.root, evictable)

    total_freed = 0
    # Oldest last_access first (LRU order).
    for victim in sorted(evictable, key=lambda n: n.last_access):
        if total_freed >= required_slots:
            break
        self.pool.free(victim.kv_indices)
        total_freed += len(victim.kv_indices)
        # NOTE(review): if a candidate can be an interior node, removing it
        # also detaches its descendants without freeing their slots; confirm
        # _collect_candidates only returns leaves.
        self._remove_node(victim)

    return total_freed >= required_slots
Contiguous format: all tokens’ KV cache entries are stored in a single contiguous tensor:
# Layout: [num_tokens, num_layers, 2 (K/V), num_kv_heads, head_dim]
kv_cache = torch.zeros(
    (seq_len, num_layers, 2, num_kv_heads, head_dim),
    dtype=torch.float16,
    device="cuda",
)

# Entry for token j in layer i: index 0 of the K/V axis is K, index 1 is V.
token_kv = kv_cache[j, i]
k = token_kv[0]  # [num_kv_heads, head_dim]
v = token_kv[1]  # [num_kv_heads, head_dim]
Paged Format
The KV cache is split into fixed-size pages (as in PagedAttention):
page_size = 16  # tokens per page
num_pages = (seq_len + page_size - 1) // page_size  # ceil(seq_len / page_size)

# Layout: [num_pages, page_size, num_layers, 2 (K/V), num_kv_heads, head_dim]
kv_cache = torch.zeros(
    (num_pages, page_size, num_layers, 2, num_kv_heads, head_dim),
    dtype=torch.float16,
    device="cuda",
)

# Page table: logical page -> physical page (`...` stands for the remaining entries)
page_table = [0, 1, 2, ...]
token_idx = 25
page_idx, offset = divmod(token_idx, page_size)  # (1, 9)
k = kv_cache[page_table[page_idx], offset, layer_idx, 0]
Memory Optimizations
1. Quantized KV Cache
Store KV cache in lower precision:
class QuantizedKVCache:
    """INT8-quantized KV cache with one FP16 scale per head vector."""

    def __init__(self, *args, **kwargs):
        # NOTE(review): max_tokens, num_layers, num_kv_heads and head_dim are
        # not parameters; they are assumed to be defined in an enclosing scope.
        # *args / **kwargs are accepted but unused.

        # Store in INT8 instead of FP16
        self.kv_data = torch.zeros(
            (max_tokens, num_layers, 2, num_kv_heads, head_dim),
            dtype=torch.int8,  # 1 byte instead of 2
            device="cuda"
        )
        # Store scale factors for dequantization: one per head vector
        # (trailing dim of size 1 instead of head_dim).
        self.scales = torch.zeros(
            (max_tokens, num_layers, 2, num_kv_heads, 1),
            dtype=torch.float16,
            device="cuda"
        )

    def store(self, layer_idx, kv_fp16):
        """Quantize ``kv_fp16`` to INT8 and store it together with its scales.

        Symmetric per-head quantization: each vector along the last axis gets
        scale ``max|x| / 127``, matching the ``(..., 1)`` shape of
        ``self.scales``. A single tensor-wide scale would round small-magnitude
        heads to zero whenever another head has large values.
        """
        x = kv_fp16.float()
        amax = x.abs().amax(dim=-1, keepdim=True)
        # Floor the scale at the smallest normal value of the scale dtype: an
        # all-zero head would otherwise give scale 0 and NaN after division.
        scale = (amax / 127.0).clamp_min(torch.finfo(self.scales.dtype).tiny)
        # Quantize with exactly the (rounded) scale that load() will read back.
        scale = scale.to(self.scales.dtype)
        kv_int8 = (x / scale.float()).round().clamp(-127, 127).to(torch.int8)

        # NOTE(review): `[...]` writes the whole buffer and layer_idx is unused;
        # kv_fp16 must broadcast to the full buffer shape. A per-layer version
        # would index with layer_idx -- confirm the intended layout.
        self.kv_data[...] = kv_int8
        self.scales[...] = scale

    def load(self, layer_idx):
        """Dequantize the stored INT8 values back to FP16."""
        kv_int8 = self.kv_data[...]
        scale = self.scales[...]
        # Multiply in FP32 to avoid compounding FP16 rounding, then cast.
        return (kv_int8.float() * scale.float()).to(torch.float16)
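Memory savings: ~50% (INT8 vs FP16). The per-head FP16 scales add 2 bytes for every head_dim INT8 values, so with head_dim = 128 the net saving is about 49%.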
2. HiCache (L3 Storage)
Offload cold KV cache to CPU or SSD:
class HiCache:
    """Hierarchical cache with L1 (GPU), L2 (CPU), L3 (SSD).

    L1 is kept in least-recently-used order (dict insertion order, oldest
    first) and bounded to ``l1_max_size`` entries; entries pushed out of L1
    are copied to L2 on the CPU. Hits in L2/L3 are promoted to L1 and also
    stay in the lower tier.
    """

    def __init__(self, l1_max_size=None):
        """
        Args:
            l1_max_size: Maximum number of L1 entries. Defaults to the
                module-level ``L1_MAX_SIZE`` constant.
        """
        self.l1_max_size = L1_MAX_SIZE if l1_max_size is None else l1_max_size
        self.l1_cache = {}  # GPU: Hot cache, least recently used first
        self.l2_cache = {}  # CPU: Warm cache
        self.l3_cache = {}  # SSD: Cold cache (key -> source accepted by torch.load)

    def get(self, key):
        """Get KV cache from the hierarchy, or ``None`` if absent everywhere."""
        # Check L1 (GPU); re-insert to mark as most recently used.
        if key in self.l1_cache:
            kv = self.l1_cache.pop(key)
            self.l1_cache[key] = kv
            return kv

        # Check L2 (CPU)
        if key in self.l2_cache:
            kv = self.l2_cache[key].cuda()  # Move to GPU
            self._insert_l1(key, kv)
            return kv

        # Check L3 (SSD)
        if key in self.l3_cache:
            # NOTE(review): torch.load unpickles its input; only load files
            # written by this process (or pass weights_only=True).
            kv = torch.load(self.l3_cache[key]).cuda()
            self._insert_l1(key, kv)
            return kv

        return None

    def put(self, key, kv):
        """Store KV cache in L1, demoting LRU entries to L2 if L1 overflows."""
        self._insert_l1(key, kv)

    def _get_lru_key(self):
        """Return the least recently used L1 key (the oldest in dict order)."""
        return next(iter(self.l1_cache))

    def _insert_l1(self, key, kv):
        """Insert into L1 as most recently used, then enforce the L1 bound."""
        self.l1_cache.pop(key, None)
        self.l1_cache[key] = kv
        # Loop rather than evict once so the bound holds after every insert,
        # including promotions from get().
        while self.l1_cache and len(self.l1_cache) > self.l1_max_size:
            evict_key = self._get_lru_key()
            self.l2_cache[evict_key] = self.l1_cache.pop(evict_key).cpu()
3. Multi-Query / Grouped-Query Attention (MQA/GQA)
Reduce KV cache size by letting several query heads share each KV head:
# KV cache elements per token per layer = num_kv_heads * head_dim * 2 (K and V)

# Standard multi-head attention: one KV head per query head
num_kv_heads = 32
kv_size = num_kv_heads * head_dim * 2  # baseline

# Grouped-query attention (GQA): query heads share 8 KV heads
num_kv_heads = 8
kv_size = num_kv_heads * head_dim * 2  # 4x smaller than baseline

# Multi-query attention (MQA): all query heads share a single KV head
num_kv_heads = 1
kv_size = num_kv_heads * head_dim * 2  # 32x smaller than baseline
Monitoring and Debugging
Memory Usage Statistics
def get_memory_stats(self):
    """Return a snapshot of KV cache and GPU memory usage.

    KV pool byte counts are derived from the number of free slots times the
    per-token size. ``cache_hit_rate`` is 0.0 until at least one request has
    been counted, instead of raising ZeroDivisionError.
    """
    pool = self.pool
    free_bytes = len(pool.free_slots) * pool.token_size
    hit_rate = self.cache_hits / self.total_requests if self.total_requests else 0.0
    return {
        "total_kv_cache": pool.size,
        "used_kv_cache": pool.size - free_bytes,
        "free_kv_cache": free_bytes,
        "cache_hit_rate": hit_rate,
        "radix_tree_nodes": self.radix_cache.total_nodes,
        "gpu_memory_allocated": torch.cuda.memory_allocated(),
        "gpu_memory_reserved": torch.cuda.memory_reserved(),
    }
Visualize Memory Usage
import matplotlib.pyplot as plt
def plot_memory_usage(stats_history):
    """Plot used and free KV cache memory (GB) against time.

    Args:
        stats_history: Sequence of snapshots, each a dict with "time",
            "used_kv_cache" and "free_kv_cache" keys (the latter two in bytes).
    """
    bytes_per_gb = 1e9
    timestamps = [snap["time"] for snap in stats_history]
    used_gb = [snap["used_kv_cache"] / bytes_per_gb for snap in stats_history]
    free_gb = [snap["free_kv_cache"] / bytes_per_gb for snap in stats_history]

    plt.figure(figsize=(10, 6))
    for series, label in ((used_gb, "Used"), (free_gb, "Free")):
        plt.plot(timestamps, series, label=label)
    plt.xlabel("Time (s)")
    plt.ylabel("Memory (GB)")
    plt.legend()
    plt.title("KV Cache Memory Usage")
    plt.show()
Best Practices
1. Set Appropriate Memory Fraction
# --mem-fraction-static covers model weights + KV cache (85% here);
# the remaining ~15% is headroom for PyTorch overhead.
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --mem-fraction-static 0.85
2. Enable RadixCache
# RadixCache (prefix caching) is enabled by default -- no flag needed
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B

# Turn it off if prefix caching does not benefit your workload
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --disable-radix-cache
3. Use Chunked Prefill
# Prevent large prefills from blocking decode: split long prefills into
# chunks of at most 512 tokens (SGLang's flag is --chunked-prefill-size).
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --max-prefill-tokens 16384 \
  --chunked-prefill-size 512
Resources
Next Steps