Memory Management

This guide covers SGLang’s memory management system, including KV cache allocation, radix caching, and memory optimizations.

Overview

SGLang’s memory system manages:
  • Model weights (static, loaded once)
  • KV cache (dynamic, per request)
  • Activation memory (temporary, per batch)
  • Workspace buffers (scratch space for kernels)
Location: python/sglang/srt/mem_cache/

Memory Layout

GPU Memory Breakdown

Total GPU Memory (e.g., 80GB on A100)
├── Model Weights (static) ~ 15GB
│   └── Fixed after loading
├── KV Cache Pool (dynamic) ~ 60GB
│   ├── Request 1: 2GB
│   ├── Request 2: 1.5GB
│   ├── Request 3: 3GB
│   └── Free space: 53.5GB
├── Activation Memory (temp) ~ 4GB
│   └── Reused across batches
└── Workspace Buffers ~ 1GB
    └── Scratch space for kernels

Memory Allocation at Startup

import logging

import torch

logger = logging.getLogger(__name__)


def init_memory_pool(self, config):
    """Initialize memory pools."""
    # Calculate available memory (bytes)
    total_mem = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    model_mem = self.get_model_memory()
    
    # mem_fraction_static covers model weights + KV cache; the rest is headroom
    kv_cache_mem = int(total_mem * config.mem_fraction_static) - model_mem
    token_size = self.get_kv_cache_token_size()  # bytes per token
    
    # Create memory pool
    self.token_to_kv_pool = TokenToKVPool(size=kv_cache_mem, token_size=token_size)
    
    logger.info(f"KV cache pool size: {kv_cache_mem / 1e9:.2f} GB")
    logger.info(f"Max tokens: {kv_cache_mem // token_size}")
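The budget arithmetic behind init_memory_pool can be checked in isolation. A minimal sketch, using the illustrative figures from the breakdown above (80 GB GPU, 15 GB of weights) and the 131,072-bytes-per-token figure derived in the next section; kv_pool_budget is a hypothetical helper, and SGLang's actual startup path may compute the budget differently.

```python
def kv_pool_budget(total_bytes, mem_fraction_static, model_bytes, bytes_per_token):
    """Return (kv_pool_bytes, max_tokens) for a static-memory budget.

    mem_fraction_static covers model weights + KV cache pool; the rest of the
    GPU is left for activations, workspace buffers, and framework overhead.
    """
    kv_pool_bytes = int(total_bytes * mem_fraction_static) - model_bytes
    if kv_pool_bytes <= 0:
        raise ValueError("mem_fraction_static is too small to hold the model weights")
    return kv_pool_bytes, kv_pool_bytes // bytes_per_token


GB = 10**9
pool_bytes, max_tokens = kv_pool_budget(80 * GB, 0.85, 15 * GB, 131_072)
print(f"KV pool: {pool_bytes / GB:.1f} GB, max tokens: {max_tokens:,}")
# KV pool: 53.0 GB, max tokens: 404,357
```

Raising mem_fraction_static grows the pool (and thus the number of tokens that can be cached concurrently) at the cost of headroom for everything else.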

Token-to-KV Pool

Architecture

The TokenToKVPool manages KV cache allocation at token granularity: each token occupies one slot, and a request's KV cache is the list of slots it holds.
import torch


class TokenToKVPool:
    """Manages KV cache memory at token level (simplified)."""
    
    def __init__(self, size, token_size, num_layers=32, num_kv_heads=8,
                 head_dim=128, dtype=torch.float16, device="cuda"):
        self.size = size
        self.token_size = token_size  # Bytes per token's KV cache (must match the shape below)
        self.max_tokens = size // token_size
        
        # Pre-allocate GPU memory: one slot per token
        self.kv_data = torch.empty(
            (self.max_tokens, num_layers, 2, num_kv_heads, head_dim),
            dtype=dtype,
            device=device,
        )
        
        # Free list of token slots (initially in ascending order)
        self.free_slots = list(range(self.max_tokens))
    
    def allocate(self, num_tokens):
        """Allocate KV cache for num_tokens; return slot indices, or None on OOM."""
        if len(self.free_slots) < num_tokens:
            return None  # Pool exhausted; caller must evict or wait
        
        # Take slots from the front of the free list
        allocated = self.free_slots[:num_tokens]
        del self.free_slots[:num_tokens]
        return allocated
    
    def free(self, slots):
        """Return KV cache slots to the free list."""
        self.free_slots.extend(slots)
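Keeping the free list in a Python container means every allocation builds a Python list of slot indices, which must then be converted to a tensor before it can index kv_data. Since allocation happens on every decode step, an alternative is to keep the free list itself as a tensor, so allocating or freeing n slots is a single slice or concatenation. A minimal CPU sketch; TensorFreeList is a hypothetical name, not an SGLang class.

```python
import torch


class TensorFreeList:
    """Hypothetical free list for KV slots, stored as a 1-D int64 tensor."""

    def __init__(self, max_tokens, device="cpu"):
        self.device = device
        self.free_slots = torch.arange(max_tokens, dtype=torch.int64, device=device)

    def available(self):
        return self.free_slots.numel()

    def allocate(self, num_tokens):
        """Return a tensor of num_tokens slot indices, or None if the pool is full."""
        if num_tokens > self.free_slots.numel():
            return None
        allocated = self.free_slots[:num_tokens]
        self.free_slots = self.free_slots[num_tokens:]
        return allocated

    def free(self, slots):
        """Return slots (tensor or list of ints) to the pool."""
        slots = torch.as_tensor(slots, dtype=torch.int64, device=self.device)
        self.free_slots = torch.cat((self.free_slots, slots))


pool = TensorFreeList(max_tokens=8)
a = pool.allocate(5)      # tensor([0, 1, 2, 3, 4])
b = pool.allocate(4)      # None: only 3 slots left
pool.free(a)
print(pool.available())   # 8
```

The returned slot tensor can be used directly as an index into a KV buffer, with no Python-to-tensor conversion in between.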

Per-Token KV Cache Size

For a model with:
  • num_layers = 32
  • num_kv_heads = 8 (GQA)
  • head_dim = 128
  • dtype = fp16 (2 bytes)
kv_cache_per_token = num_layers * 2 * num_kv_heads * head_dim * bytes_per_element
                   = 32 * 2 * 8 * 128 * 2    # layers * (K, V) * KV heads * head_dim * fp16 bytes
                   = 131,072 bytes = 128 KiB
For 10,000 tokens: 10,000 * 131,072 bytes = 1.31 GB (about 1.22 GiB)
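The same arithmetic as a reusable helper, handy for comparing models or KV cache dtypes. A minimal sketch; kv_bytes_per_token is a hypothetical name, and it assumes the standard per-head K/V layout described above (models with other attention layouts need a different formula).

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element=2):
    """Bytes of KV cache one token occupies across all layers (K and V)."""
    return num_layers * 2 * num_kv_heads * head_dim * bytes_per_element


per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
print(per_token)                   # 131072
print(per_token * 10_000 / 1e9)    # 1.31072 (GB)
print(kv_bytes_per_token(32, 8, 128, bytes_per_element=1))  # 65536 with 1-byte entries
```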

Radix Cache

Radix Tree for Prefix Sharing

Radix cache uses a tree structure to share KV cache across requests with common prefixes.
import time


class RadixCache:
    """Prefix tree for KV cache sharing (simplified: one token per node, i.e. an
    uncompressed trie; a full radix tree stores multi-token edges)."""
    
    class Node:
        def __init__(self, parent=None):
            self.parent = parent
            self.children = {}  # token_id -> Node
            self.kv_indices = []  # This token's slot in token_to_kv_pool
            self.ref_count = 0  # Number of active requests using this node
            self.last_access = time.monotonic()
    
    def __init__(self):
        self.root = self.Node()
        self.total_nodes = 0
    
    def match(self, tokens):
        """Return the KV slots of the longest cached prefix of tokens."""
        node = self.root
        matched_indices = []
        
        for token in tokens:
            if token not in node.children:
                break
            node = node.children[token]
            matched_indices.extend(node.kv_indices)  # one slot per node
            node.last_access = time.monotonic()
        
        return matched_indices
    
    def insert(self, tokens, kv_indices):
        """Insert a token sequence and its KV slots; existing nodes keep their slots."""
        node = self.root
        
        for token, kv_index in zip(tokens, kv_indices):
            if token not in node.children:
                child = self.Node(parent=node)
                child.kv_indices = [kv_index]
                node.children[token] = child
                self.total_nodes += 1
            
            node = node.children[token]
            node.ref_count += 1
            node.last_access = time.monotonic()
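The RadixCache above stores one token per node, so a 4,000-token prompt becomes a 4,000-node chain. A radix tree proper compresses each unbranched chain into a single edge carrying a token sequence, and splits an edge when a new sequence diverges partway along it. A minimal sketch of match and insert with edge splitting, using the token IDs from the example below; RadixNode, radix_match, and radix_insert are hypothetical names, and ref counts and timestamps are omitted for brevity.

```python
class RadixNode:
    """Hypothetical compressed node: the edge into it carries a token sequence."""

    def __init__(self, key=(), value=()):
        self.key = tuple(key)      # tokens on the edge from the parent
        self.value = list(value)   # KV slots, one per token in key
        self.children = {}         # first token of child.key -> child


def _common_len(a, b):
    n = 0
    while n < len(a) and n < len(b) and a[n] == b[n]:
        n += 1
    return n


def radix_match(root, tokens):
    """Return the KV slots of the longest cached prefix of tokens."""
    node, out, i = root, [], 0
    while i < len(tokens) and tokens[i] in node.children:
        child = node.children[tokens[i]]
        n = _common_len(child.key, tokens[i:])
        out.extend(child.value[:n])
        if n < len(child.key):
            break  # diverged partway along the edge
        node, i = child, i + n
    return out


def radix_insert(root, tokens, slots):
    """Insert tokens/slots, splitting an edge where the new sequence diverges.

    Slots for an already-cached prefix are kept; the new ones are ignored.
    """
    node, i = root, 0
    while i < len(tokens):
        child = node.children.get(tokens[i])
        if child is None:  # no edge starts with this token: add a leaf
            node.children[tokens[i]] = RadixNode(tokens[i:], slots[i:])
            return
        n = _common_len(child.key, tokens[i:])
        if n < len(child.key):  # split child into shared prefix + remainder
            mid = RadixNode(child.key[:n], child.value[:n])
            child.key, child.value = child.key[n:], child.value[n:]
            mid.children[child.key[0]] = child
            node.children[tokens[i]] = mid
            child = mid
        node, i = child, i + n


root = RadixNode()
radix_insert(root, [1054, 284, 2823, 25, 15496], [0, 1, 2, 3, 4])
radix_insert(root, [1054, 284, 2823, 25, 7197, 29474], [0, 1, 2, 3, 5, 6])
print(radix_match(root, [1054, 284, 2823, 25, 7197, 29474]))  # [0, 1, 2, 3, 5, 6]
print(len(root.children[1054].children))  # 2: the edge was split after 4 tokens
```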

Example: Prefix Sharing

# Request 1: "Translate to French: Hello"
tokens1 = [1054, 284, 2823, 25, 15496]  # "Translate to French: Hello"
kv1 = pool.allocate(len(tokens1))  # [0, 1, 2, 3, 4]
radix_cache.insert(tokens1, kv1)

# Request 2: "Translate to French: Goodbye" (shares prefix)
tokens2 = [1054, 284, 2823, 25, 7197, 29474]  # "Translate to French: Goodbye"
matched = radix_cache.match(tokens2)  # Returns [0, 1, 2, 3] (shared prefix)
remaining = len(tokens2) - len(matched)  # 2 tokens
kv2_new = pool.allocate(remaining)  # [5, 6]
kv2 = matched + kv2_new  # [0, 1, 2, 3, 5, 6]
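Memory saved: the 4 shared prefix tokens reuse slots [0, 1, 2, 3], so Request 2 needs 4 * 128 KiB = 512 KiB less KV cache (using the per-token size computed above).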

Memory Allocation Strategies

Lazy Allocation

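Allocate KV cache only when tokens need it: the prompt at prefill time (minus any cached prefix), then slots for generated tokens as they are produced, instead of reserving the maximum output length up front:
class OutOfMemoryError(RuntimeError):
    """Raised when the KV cache pool cannot satisfy an allocation."""


class Request:
    def __init__(self, input_ids, pool, radix_cache):
        self.input_ids = input_ids
        self.pool = pool
        self.radix_cache = radix_cache
        self.kv_indices = []  # Slots this request's tokens occupy, in order

    def allocate_kv_cache(self, num_new_tokens):
        """Allocate KV cache for num_new_tokens more tokens.

        On the first call (prefill) num_new_tokens is the prompt length and any
        cached prefix is reused; later calls (decode) cover generated tokens.
        """
        if not self.kv_indices:
            # Try to match prefix first
            matched = self.radix_cache.match(self.input_ids)
            self.kv_indices = list(matched)
            num_new_tokens -= len(matched)
        
        # Allocate remaining
        if num_new_tokens > 0:
            new_indices = self.pool.allocate(num_new_tokens)
            if new_indices is None:
                raise OutOfMemoryError("KV cache pool exhausted")
            
            self.kv_indices.extend(new_indices)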

Eager Eviction

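Release a request's resources as soon as it finishes. Slots that the radix tree holds for a shared prefix must not be returned to the pool here: they only lose this request's reference and stay cached until evicted.
def finish_request(self, req):
    """Clean up request resources."""
    # Decrement reference counts along the request's path in the radix tree;
    # nodes that reach ref_count == 0 stay cached as eviction candidates
    self.radix_cache.decrement_refs(req.input_ids)
    
    # Free only the slots the radix tree does not own (e.g. generated tokens)
    cached = set(self.radix_cache.match(req.input_ids))
    self.pool.free([i for i in req.kv_indices if i not in cached])
    
    # Remove from running batch
    self.running_batch.remove(req)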
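finish_request relies on a decrement_refs helper that the RadixCache class above does not define. A minimal sketch, assuming the per-token node layout used there: walk the cached path for the request's tokens and drop one reference per node. The helper name comes from the text; its body, the standalone Node class, and the return value (how many nodes became unreferenced) are assumptions.

```python
class Node:
    """Minimal per-token radix node (hypothetical, mirrors RadixCache.Node)."""

    def __init__(self):
        self.children = {}  # token_id -> Node
        self.kv_indices = []
        self.ref_count = 0


def decrement_refs(root, tokens):
    """Drop one reference from every node on the cached path for tokens.

    Returns the number of nodes whose ref_count reached 0 (now evictable).
    """
    node, released = root, 0
    for token in tokens:
        node = node.children.get(token)
        if node is None:
            break  # the rest of the sequence was never cached
        assert node.ref_count > 0, "ref_count underflow: unbalanced insert/finish"
        node.ref_count -= 1
        if node.ref_count == 0:
            released += 1
    return released


# Two requests share the prefix [7, 8]; only the first also cached token 9.
root = Node()
prev = root
for tok in (7, 8, 9):
    prev.children[tok] = Node()
    prev = prev.children[tok]
root.children[7].ref_count = 2
root.children[7].children[8].ref_count = 2
root.children[7].children[8].children[9].ref_count = 1

print(decrement_refs(root, [7, 8, 9]))  # 1: only node 9 is now unreferenced
```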

Cache Eviction Policy

When memory is full, evict least recently used (LRU) cached prefixes:
def evict_cache(self, required_slots):
    """Evict cached prefixes to free at least required_slots KV slots."""
    # Eviction candidates: leaf nodes with ref_count == 0 (evicting an inner
    # node would orphan its children)
    candidates = []
    self._collect_candidates(self.radix_cache.root, candidates)
    
    # Sort by last access time, oldest first (LRU)
    candidates.sort(key=lambda node: node.last_access)
    
    # Evict until enough memory
    freed = 0
    for node in candidates:
        if freed >= required_slots:
            break
        
        # Free this node's KV cache
        self.pool.free(node.kv_indices)
        freed += len(node.kv_indices)
        
        # Remove from tree (parents that become leaves are not revisited in this pass)
        self._remove_node(node)
    
    return freed >= required_slots
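Collecting candidates once misses nodes that only become leaves during eviction (a parent whose last child was just removed). One way to handle this is a min-heap keyed on last_access that re-pushes the parent after each removal. A minimal sketch; Node, evict_lru, and the free_fn callback are hypothetical names, and the tree uses the same per-token layout as RadixCache above plus a parent link and the node's own token.

```python
import heapq


class Node:
    """Minimal per-token radix node (hypothetical; parent link needed for eviction)."""

    def __init__(self, parent=None, token=None, slot=None, last_access=0.0):
        self.parent = parent
        self.token = token
        self.children = {}  # token_id -> Node
        self.kv_indices = [] if slot is None else [slot]
        self.ref_count = 0
        self.last_access = last_access


def evict_lru(root, required_slots, free_fn):
    """Evict unreferenced leaves, oldest first, until required_slots are freed."""
    heap, stack = [], [root]
    while stack:  # collect the current evictable leaves
        node = stack.pop()
        stack.extend(node.children.values())
        if node is not root and not node.children and node.ref_count == 0:
            heapq.heappush(heap, (node.last_access, id(node), node))

    freed = 0
    while heap and freed < required_slots:
        _, _, node = heapq.heappop(heap)
        free_fn(node.kv_indices)
        freed += len(node.kv_indices)
        parent = node.parent
        del parent.children[node.token]
        # The parent may have just become an evictable leaf
        if parent is not root and not parent.children and parent.ref_count == 0:
            heapq.heappush(heap, (parent.last_access, id(parent), parent))
    return freed >= required_slots


freed_slots = []
root = Node()
a = Node(root, token=10, slot=0, last_access=1.0)
root.children[10] = a
b = Node(a, token=11, slot=1, last_access=2.0)
a.children[11] = b
c = Node(root, token=12, slot=2, last_access=3.0)
root.children[12] = c

print(evict_lru(root, 2, freed_slots.extend))  # True
print(freed_slots)  # [1, 0]: leaf b first, then its newly exposed parent a
```

The id(node) tie-breaker keeps heapq from ever comparing two Node objects when timestamps are equal.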

KV Cache Formats

Contiguous Format

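Each token's KV cache for all layers is stored in one buffer, indexed directly by token position:
import torch

seq_len, num_layers, num_kv_heads, head_dim = 1024, 32, 8, 128  # example sizes

# Shape: [num_tokens, num_layers, 2, num_kv_heads, head_dim]
kv_cache = torch.zeros(
    (seq_len, num_layers, 2, num_kv_heads, head_dim),
    dtype=torch.float16,
    device="cuda"
)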

# Access K for layer i, token j
k = kv_cache[j, i, 0]  # [num_kv_heads, head_dim]

# Access V for layer i, token j
v = kv_cache[j, i, 1]  # [num_kv_heads, head_dim]

Paged Format

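The KV cache is split into fixed-size pages (as in PagedAttention); a per-sequence page table maps each logical page to a physical page in a shared pool:
import torch

page_size = 16  # tokens per page
num_pages = 64  # physical pages in the shared pool (example size)
num_layers, num_kv_heads, head_dim = 32, 8, 128

# Pages one sequence needs
seq_len = 40
pages_needed = (seq_len + page_size - 1) // page_size  # 3

# Shape: [num_pages, page_size, num_layers, 2, num_kv_heads, head_dim]
kv_cache = torch.zeros(
    (num_pages, page_size, num_layers, 2, num_kv_heads, head_dim),
    dtype=torch.float16,
    device="cuda"
)

# Page table for this sequence: logical page i lives in physical page page_table[i]
page_table = [3, 7, 1]
layer_idx = 0
token_idx = 25
page_idx = token_idx // page_size     # 1
offset = token_idx % page_size        # 9
physical_page = page_table[page_idx]  # 7
k = kv_cache[physical_page, offset, layer_idx, 0]  # [num_kv_heads, head_dim]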
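Conceptually, attention over a paged cache performs this lookup for every token of the sequence. The sketch below does it in bulk for one layer's K cache: it turns a page table into flat slot indices and gathers the keys back into logical order. It runs on CPU with tiny sizes; token_slots and gather_keys are hypothetical helper names.

```python
import torch

page_size, num_pages, num_kv_heads, head_dim = 4, 8, 2, 3

# Single-layer paged K cache: [num_pages, page_size, num_kv_heads, head_dim]
k_pages = torch.randn(num_pages, page_size, num_kv_heads, head_dim)


def token_slots(page_table, seq_len, page_size):
    """Flat physical slot (page * page_size + offset) for each logical token."""
    pos = torch.arange(seq_len)
    pages = torch.as_tensor(page_table)[pos // page_size]
    return pages * page_size + pos % page_size


def gather_keys(k_pages, page_table, seq_len):
    """Collect a sequence's keys in logical order using its page table."""
    flat = k_pages.reshape(-1, *k_pages.shape[2:])  # [num_pages * page_size, H, D]
    return flat[token_slots(page_table, seq_len, k_pages.shape[1])]


page_table = [5, 2, 6]  # logical pages 0..2 -> physical pages
keys = gather_keys(k_pages, page_table, seq_len=10)
print(keys.shape)  # torch.Size([10, 2, 3])
```

With page_size = 1 every page holds a single token and the page table reduces to a per-token slot list, which is the layout the TokenToKVPool section uses.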

Memory Optimizations

1. Quantized KV Cache

Store KV cache in lower precision:
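import torch


class QuantizedKVCache:
    """INT8-quantized KV cache with one FP16 scale per (token, layer, K/V, head)."""
    
    def __init__(self, max_tokens, num_layers, num_kv_heads, head_dim, device="cuda"):
        # Store in INT8 instead of FP16: 1 byte per element instead of 2
        self.kv_data = torch.zeros(
            (max_tokens, num_layers, 2, num_kv_heads, head_dim),
            dtype=torch.int8,
            device=device,
        )
        # Scale factors for dequantization, one per head vector
        self.scales = torch.ones(
            (max_tokens, num_layers, 2, num_kv_heads, 1),
            dtype=torch.float16,
            device=device,
        )
    
    def store(self, slots, layer_idx, kv_fp16):
        """Quantize kv_fp16 [len(slots), 2, num_kv_heads, head_dim] into slots."""
        # Symmetric quantization: the largest |value| in each head vector maps to 127
        scale = (kv_fp16.float().abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-4).half()
        kv_int8 = (kv_fp16.float() / scale.float()).round().clamp(-127, 127).to(torch.int8)
        
        # Store quantized values and scales for this layer only
        self.kv_data[slots, layer_idx] = kv_int8
        self.scales[slots, layer_idx] = scale
    
    def load(self, slots, layer_idx):
        """Dequantize the KV cache of slots for one layer back to FP16."""
        kv_int8 = self.kv_data[slots, layer_idx]
        scale = self.scales[slots, layer_idx]
        return kv_int8.to(torch.float16) * scale
Memory Savings: about 49%. INT8 halves element storage, and the FP16 scales add 2 bytes per head vector, so with head_dim = 128 each head vector takes 130 bytes instead of 256.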
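The round trip can be checked on CPU without the cache class: per-vector symmetric quantization keeps each element's error within half a quantization step (scale / 2), while storage drops to one byte per element plus one scale per head vector. A minimal sketch; quantize_int8 and dequantize_int8 are hypothetical helpers, and the data is random.

```python
import torch


def quantize_int8(x):
    """Per-vector symmetric INT8 quantization along the last dim."""
    scale = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale


def dequantize_int8(q, scale):
    return q.to(torch.float32) * scale


torch.manual_seed(0)
k = torch.randn(4, 8, 128)  # [tokens, kv_heads, head_dim], float32 on CPU
q, scale = quantize_int8(k)
err = (dequantize_int8(q, scale) - k).abs()

# Rounding error is at most half a quantization step per element
print(bool((err <= scale / 2 + 1e-6).all()))  # True
# Bytes as FP16 vs INT8 values + FP16 scales (scales are kept in float32 here)
print(k.numel() * 2, q.numel() + 2 * scale.numel())  # 8192 4160
```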

2. HiCache (L3 Storage)

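Offload cold KV cache from GPU memory to host (CPU) memory or SSD, and bring it back when it is reused. A simplified sketch keyed by prefix:
import os
from collections import OrderedDict

import torch


class HiCache:
    """Hierarchical cache with L1 (GPU), L2 (CPU), L3 (SSD); simplified sketch."""
    
    def __init__(self, l1_max_entries, l2_max_entries, l3_dir, device="cuda"):
        self.l1_cache = OrderedDict()  # GPU: hot cache, least recently used first
        self.l2_cache = OrderedDict()  # CPU: warm cache, least recently used first
        self.l3_cache = {}             # SSD: cold cache, key -> file path
        self.l1_max_entries = l1_max_entries
        self.l2_max_entries = l2_max_entries
        self.l3_dir = l3_dir           # keys are assumed filename-safe (e.g. a prefix hash)
        self.device = device
    
    def get(self, key):
        """Get KV cache from the hierarchy, promoting it to L1 on a hit."""
        # Check L1 (GPU)
        if key in self.l1_cache:
            self.l1_cache.move_to_end(key)  # mark as most recently used
            return self.l1_cache[key]
        
        # Check L2 (CPU), then L3 (SSD)
        if key in self.l2_cache:
            kv = self.l2_cache.pop(key).to(self.device)
        elif key in self.l3_cache:
            path = self.l3_cache.pop(key)
            kv = torch.load(path).to(self.device)
            os.remove(path)
        else:
            return None
        
        self.put(key, kv)  # promote to L1 (may demote other entries)
        return kv
    
    def put(self, key, kv):
        """Store KV cache in L1, demoting LRU entries down the hierarchy."""
        self.l1_cache[key] = kv
        self.l1_cache.move_to_end(key)
        
        # L1 full: move its LRU entry to L2 (CPU)
        if len(self.l1_cache) > self.l1_max_entries:
            old_key, old_kv = self.l1_cache.popitem(last=False)
            self.l2_cache[old_key] = old_kv.cpu()
        
        # L2 full: write its LRU entry to L3 (SSD)
        if len(self.l2_cache) > self.l2_max_entries:
            old_key, old_kv = self.l2_cache.popitem(last=False)
            path = os.path.join(self.l3_dir, f"{old_key}.pt")
            torch.save(old_kv, path)
            self.l3_cache[old_key] = path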

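3. Multi-Query and Grouped-Query Attention (MQA/GQA)

Models with GQA or MQA share each K/V head across several query heads, which shrinks the KV cache in proportion to the number of KV heads. This is fixed by the model architecture, not a serving option:
# Standard multi-head attention (assuming 32 query heads)
num_kv_heads = 32  # Same as num_query_heads
kv_size = 32 * head_dim * 2  # elements per token per layer (K and V)

# Grouped-query attention (GQA)
num_kv_heads = 8   # Fewer than num_query_heads
kv_size = 8 * head_dim * 2   # 4x smaller

# Multi-query attention (MQA)
num_kv_heads = 1   # Single KV head
kv_size = 1 * head_dim * 2   # 32x smaller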
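GQA keeps the cache small because only num_kv_heads K/V heads are stored; at attention time each KV head serves a group of query heads. A minimal CPU sketch of that mapping, expanding the KV heads with repeat_interleave only for the computation (optimized kernels can index the shared head instead of materializing copies); the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

num_q_heads, num_kv_heads, seq_len, head_dim = 32, 8, 16, 64
group = num_q_heads // num_kv_heads  # query heads per KV head (4)

torch.manual_seed(0)
q = torch.randn(num_q_heads, seq_len, head_dim)
k = torch.randn(num_kv_heads, seq_len, head_dim)  # only 8 heads are cached
v = torch.randn(num_kv_heads, seq_len, head_dim)

# Query head h reads KV head h // group: expand KV heads at compute time only
k_full = k.repeat_interleave(group, dim=0)  # [32, seq_len, head_dim]
v_full = v.repeat_interleave(group, dim=0)
out = F.scaled_dot_product_attention(q, k_full, v_full, is_causal=True)
print(out.shape)  # torch.Size([32, 16, 64])
```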

Monitoring and Debugging

Memory Usage Statistics

import time

import torch


def get_memory_stats(self):
    """Get memory usage statistics (KV cache figures in bytes)."""
    pool = self.pool
    free_bytes = len(pool.free_slots) * pool.token_size
    total_bytes = pool.max_tokens * pool.token_size
    return {
        "time": time.time(),
        "total_kv_cache": total_bytes,
        "used_kv_cache": total_bytes - free_bytes,
        "free_kv_cache": free_bytes,
        "cache_hit_rate": self.cache_hits / max(self.total_requests, 1),
        "radix_tree_nodes": self.radix_cache.total_nodes,
        # PyTorch allocator totals include the whole pre-allocated KV pool,
        # so they stay flat as KV usage changes
        "gpu_memory_allocated": torch.cuda.memory_allocated(),
        "gpu_memory_reserved": torch.cuda.memory_reserved(),
    }

Visualize Memory Usage

import matplotlib.pyplot as plt

def plot_memory_usage(stats_history):
    """Plot KV cache usage over time (seconds since the first sample)."""
    t0 = stats_history[0]["time"]
    times = [s["time"] - t0 for s in stats_history]
    used = [s["used_kv_cache"] / 1e9 for s in stats_history]  # GB
    free = [s["free_kv_cache"] / 1e9 for s in stats_history]
    
    plt.figure(figsize=(10, 6))
    plt.plot(times, used, label="Used")
    plt.plot(times, free, label="Free")
    plt.xlabel("Time (s)")
    plt.ylabel("Memory (GB)")
    plt.legend()
    plt.title("KV Cache Memory Usage")
    plt.show()

Best Practices

1. Set Appropriate Memory Fraction

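# mem-fraction-static is the share of GPU memory for model weights + KV cache pool;
# the rest is headroom for activations and PyTorch overhead (lower it if you hit OOM)
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --mem-fraction-static 0.85  # 85% for model weights + KV cache pool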

2. Enable RadixCache

# Prefix caching (RadixCache) is enabled by default; no flag is needed
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B

# Disable it, e.g. when requests rarely share prefixes
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --disable-radix-cache

3. Use Chunked Prefill

# Split long prompts into chunks so a large prefill does not stall decoding
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --max-prefill-tokens 16384 \
  --chunked-prefill-size 512
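With chunking, a long prompt is prefilled over several scheduler iterations instead of one, so running requests can keep decoding between chunks. A minimal sketch of how a 2,000-token prompt would be split with a 512-token chunk size; prefill_chunks is a hypothetical helper, and the actual scheduler's policy is more involved.

```python
def prefill_chunks(prompt_len, chunk_size):
    """Split a prompt into (start, end) token ranges of at most chunk_size tokens."""
    return [(start, min(start + chunk_size, prompt_len))
            for start in range(0, prompt_len, chunk_size)]


print(prefill_chunks(2000, 512))
# [(0, 512), (512, 1024), (1024, 1536), (1536, 2000)]
```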
