Documentation Index
Fetch the complete documentation index at: https://mintlify.com/sgl-project/sglang/llms.txt
Use this file to discover all available pages before exploring further.
Scheduler
The scheduler is the core component that manages request batching, memory allocation, and execution orchestration in SGLang.
Overview
Location: python/sglang/srt/managers/scheduler.py
Key Responsibilities:
- Request queueing and prioritization
- Dynamic batch formation
- Memory allocation via token-to-KV pool
- Prefix cache management (RadixAttention)
- Request lifecycle management
Request States
A request transitions through several states:
┌─────────┐
│ Waiting │  Initial state, in waiting queue
└────┬────┘
     │
     ▼
┌─────────┐
│ Running │  Executing in a batch
└────┬────┘
     │
     ├──→ ┌──────────┐
     │    │ Finished │  Generation complete
     │    └──────────┘
     │
     └──→ ┌──────────┐
          │ Aborted  │  Cancelled by user
          └──────────┘
Scheduling Loop
Main Loop
def event_loop(self):
    """Run the scheduler forever: ingest requests, step once, publish results.

    Every pass moves newly received requests into the waiting queue, runs a
    single batch step when anything is running or waiting, and then calls
    ``send_results``. The loop has no exit condition of its own.
    """
    while True:
        # Requests come from the tokenizer manager and are queued in the
        # order they were received.
        for incoming in self.recv_requests():
            self.waiting_queue.append(incoming)

        # Only step when there is in-flight or queued work.
        has_work = self.running_batch or self.waiting_queue
        if has_work:
            self.process_batch()

        # Results are handed to the detokenizer on every pass, including
        # passes where no step ran.
        self.send_results()
Batch Processing
def process_batch(self):
    """Advance the scheduler by exactly one step.

    Builds the next batch, runs the prefill and/or decode phase the batch
    asks for, samples the next tokens, applies them to the requests, and
    finally checks which requests have finished.
    """
    step_batch = self.get_next_batch()

    # A batch may carry new (prefill) work, in-flight (decode) work, or both.
    if step_batch.has_prefill:
        self.run_prefill(step_batch)
    if step_batch.has_decode:
        self.run_decode(step_batch)

    # NOTE(review): sample_tokens receives only the batch here; confirm this
    # matches the signature of the sample_tokens implementation in use.
    sampled = self.sample_tokens(step_batch)

    self.update_requests(step_batch, sampled)
    self.check_finished(step_batch)
Dynamic Batching
The scheduler dynamically forms batches based on:
- Available memory
- Request priorities
- Prefill chunking constraints
def get_next_batch(self):
    """Assemble the batch for the next scheduling step.

    Every request already in ``running_batch`` is added as a decode request.
    Waiting requests are then admitted strictly from the front of the queue
    until one cannot get KV-cache memory. Each admitted request is added as
    a prefill request (with the chunk size if it should be chunked,
    otherwise with its full input length), removed from the waiting queue,
    and appended to ``running_batch``.

    Returns:
        The populated ``ScheduleBatch``.
    """
    next_batch = ScheduleBatch()

    # In-flight requests continue decoding.
    for running_req in self.running_batch:
        next_batch.add_decode_req(running_req)

    while self.waiting_queue:
        head = self.waiting_queue[0]

        # Admission stops at the first request that does not fit; later
        # requests are not considered in this step.
        if not self.can_allocate_kv_cache(head):
            break

        if self.should_chunk_prefill(head):
            prefill_len = self.get_prefill_chunk_size()
        else:
            prefill_len = len(head.input_ids)
        next_batch.add_prefill_req(head, prefill_len)

        # NOTE(review): a chunked request also leaves the waiting queue after
        # its first chunk -- confirm the remaining chunks are tracked elsewhere.
        self.waiting_queue.pop(0)
        self.running_batch.append(head)

    return next_batch
Continuous Batching
Requests can join or leave batches at any time:
# Illustration only: batch membership changes from one iteration to the next.
# Iteration 0: three requests start together
batch = [req1, req2, req3] # Initial batch
# Iteration 1: req4 arrives, req2 finishes
batch = [req1, req3, req4]
# Iteration 2: req5, req6 arrive, req1 finishes
batch = [req3, req4, req5, req6]
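Because finished requests leave the batch immediately and new requests join without waiting for the batch to drain, this keeps GPU utilization higher than static batching.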
Chunked Prefill
Why Chunk?
Large prefills can block decode requests, increasing latency:
Without chunking:
[────── Long prefill (2s) ──────][Decode][Decode][Decode]
                                 ↑ High latency
With chunking:
[Prefill chunk][Decode][Prefill chunk][Decode][Prefill chunk][Decode]
               ↑ Low latency
Implementation
def should_chunk_prefill(self, req):
    """Return True when the request's input is longer than ``max_prefill_tokens``."""
    prompt_len = len(req.input_ids)
    return prompt_len > self.max_prefill_tokens
def get_prefill_chunk_size(self, busy_threshold=10, small_chunk=512, large_chunk=2048):
    """Pick a prefill chunk size from the current decode load.

    Args:
        busy_threshold: Number of running requests above which the scheduler
            is considered busy with decode work.
        small_chunk: Chunk size returned when busy (small chunks keep decode
            latency low).
        large_chunk: Chunk size returned otherwise (large chunks favour
            throughput).

    Returns:
        ``small_chunk`` if ``len(self.running_batch) > busy_threshold``,
        else ``large_chunk``. The defaults reproduce the previously
        hard-coded values.
    """
    is_busy = len(self.running_batch) > busy_threshold
    return small_chunk if is_busy else large_chunk
Memory Management
Token-to-KV Pool
The scheduler allocates KV cache via a memory pool:
class MemoryPool:
    """First-fit allocator over the range ``[0, total_size)``.

    Free space is tracked as a list of ``Block(start, size)`` objects kept in
    ascending ``start`` order. Freed blocks are coalesced with their
    neighbours so the pool does not fragment permanently.
    """

    def __init__(self, total_size):
        """Create a pool whose whole range is initially one free block."""
        self.total_size = total_size
        self.free_blocks = [Block(0, total_size)]  # Initially all free

    @property
    def total(self):
        """Total capacity of the pool (same value as ``total_size``)."""
        return self.total_size

    @property
    def used(self):
        """Capacity currently handed out: total minus the sum of free blocks."""
        return self.total_size - sum(block.size for block in self.free_blocks)

    def allocate(self, size):
        """Carve ``size`` units out of the first free block that is large enough.

        Args:
            size: Number of units requested; must not be negative.

        Returns:
            A ``Block(start, size)`` on success, or ``None`` when no free
            block is large enough (out of memory).

        Raises:
            ValueError: If ``size`` is negative. A negative size would
                otherwise grow the remainder block beyond its original
                bounds and corrupt the free list.
        """
        if size < 0:
            raise ValueError(f"allocation size must be non-negative, got {size}")

        for index, block in enumerate(self.free_blocks):
            if block.size < size:
                continue
            allocated = Block(block.start, size)
            leftover = block.size - size
            if leftover > 0:
                # Replace in place so the free list stays sorted by start.
                self.free_blocks[index] = Block(block.start + size, leftover)
            else:
                del self.free_blocks[index]
            # Returning right after the mutation keeps the iteration safe.
            return allocated
        return None  # OOM

    def free(self, block):
        """Return ``block`` to the pool and coalesce it with its neighbours."""
        self.free_blocks.append(block)
        self.merge_adjacent_blocks()  # Coalesce

    def merge_adjacent_blocks(self):
        """Sort the free list by start and merge blocks that touch or overlap."""
        merged = []
        for block in sorted(self.free_blocks, key=lambda blk: blk.start):
            if block.size <= 0:
                # Zero-size blocks carry no capacity; drop them.
                continue
            if merged:
                last = merged[-1]
                last_end = last.start + last.size
                if block.start <= last_end:
                    new_end = max(last_end, block.start + block.size)
                    merged[-1] = Block(last.start, new_end - last.start)
                    continue
            merged.append(block)
        self.free_blocks = merged
Eviction Policy
When memory is full, the scheduler can evict cached prefixes:
def evict_cache(self, required_size):
    """Free cached prefixes, least recently accessed first, until enough is reclaimed.

    Args:
        required_size: Amount of cache that must be reclaimed.

    Returns:
        True if the summed ``cache_size`` of the evicted prefixes reached
        ``required_size``; False if every candidate was evicted and the
        total still fell short.
    """
    reclaimed = 0
    # Oldest access time first, i.e. LRU order.
    lru_order = sorted(self.cached_prefixes, key=lambda entry: entry.last_access_time)

    for entry in lru_order:
        if reclaimed >= required_size:
            return True
        self.free_kv_cache(entry)
        reclaimed += entry.cache_size

    return reclaimed >= required_size
RadixAttention (Prefix Caching)
Radix Tree Structure
The scheduler maintains a radix tree to track shared prefixes:
class RadixNode:
    """A single node of the prefix tree; children are keyed by token."""

    def __init__(self):
        # Next token -> child RadixNode extending this prefix.
        self.children = dict()
        # Location of the KV cache for this prefix; None until it is set.
        self.kv_cache_indices = None
        # How many requests are using this prefix.
        self.ref_count = 0
Prefix Matching
When a new request arrives:
def match_prefix(self, tokens):
    """Walk the radix tree along ``tokens`` and return the reusable KV indices.

    Args:
        tokens: Token sequence of the new request.

    Returns:
        The first ``depth`` entries of the deepest matched node's
        ``kv_cache_indices``, where ``depth`` is the number of leading
        tokens found in the tree; ``None`` when not even the first token
        matches.
    """
    current = self.radix_tree_root
    depth = 0

    for token in tokens:
        if token not in current.children:
            break
        current = current.children[token]
        depth += 1

    if depth == 0:
        return None
    # The matched prefix can reuse the cache recorded on the deepest node.
    return current.kv_cache_indices[:depth]
Cache Insertion
After computing new KV cache:
def insert_cache(self, tokens, kv_indices):
    """Record a token sequence and its KV indices in the radix tree.

    Walks from the root one token at a time, creating missing nodes. Each
    node on the path stores the slice of ``kv_indices`` covering the prefix
    that ends at that node, and its ``ref_count`` is incremented.

    Args:
        tokens: Token sequence whose KV cache was just computed.
        kv_indices: KV cache indices, sliced per prefix length.
    """
    current = self.radix_tree_root
    for depth, token in enumerate(tokens, start=1):
        branches = current.children
        if token not in branches:
            branches[token] = RadixNode()
        current = branches[token]
        current.kv_cache_indices = kv_indices[:depth]
        current.ref_count += 1
Request Prioritization
Priority Levels
Requests can have different priorities:
class Priority:
    """Request priority levels as plain integers; a larger value is more urgent."""

    HIGH, NORMAL, LOW = 2, 1, 0
Scheduling with Priority
def get_next_requests(self):
    """Select the next requests to schedule, most urgent first.

    The waiting queue is sorted in place: higher ``priority`` first and,
    within one priority level, earlier ``arrival_time`` first (FIFO).
    Sorting the ``(priority, arrival_time)`` tuple with ``reverse=True``
    would instead serve the newest arrival first within a level and let
    older requests starve.

    Returns:
        Up to ``max_batch_size`` requests, in scheduling order, for which
        ``can_allocate`` returned true. Requests that cannot be allocated
        are skipped rather than blocking those behind them.
    """
    # Negating the numeric priority gives "high first" while leaving the
    # arrival time in ascending order.
    self.waiting_queue.sort(key=lambda req: (-req.priority, req.arrival_time))

    batch = []
    for req in self.waiting_queue:
        # Check the limit before admitting, so a limit of 0 admits nothing.
        if len(batch) >= self.max_batch_size:
            break
        if self.can_allocate(req):
            batch.append(req)
    return batch
Sampling
Token Sampling
After model forward pass, sample next tokens:
def sample_tokens(self, batch, logits):
    """Sample one next token for every request in the batch.

    For request ``i`` the logits at ``logits[i, -1, :]`` (the last position)
    are passed through ``apply_penalties`` together with the request's
    output ids and sampling params, then handed to ``self.sampler.sample``
    with the request's temperature, top-p and top-k.

    Args:
        batch: Object exposing ``reqs``, the requests in batch order.
        logits: Model output indexed as ``[request, position, vocab]``.

    Returns:
        List of sampled tokens, one per request, in batch order.
    """
    sampled = []
    for row, request in enumerate(batch.reqs):
        last_position_logits = logits[row, -1, :]  # Last token

        penalized = self.apply_penalties(
            last_position_logits,
            request.output_ids,
            request.sampling_params,
        )

        choice = self.sampler.sample(
            penalized,
            temperature=request.sampling_params.temperature,
            top_p=request.sampling_params.top_p,
            top_k=request.sampling_params.top_k,
        )
        sampled.append(choice)
    return sampled
Penalties
def apply_penalties(self, logits, output_ids, params):
    """Apply frequency, presence and repetition penalties to ``logits`` in place.

    Each penalty is applied once per distinct token in ``output_ids``:

    * frequency: subtract ``frequency_penalty * occurrences``;
    * presence: subtract ``presence_penalty``;
    * repetition: multiply negative logits and divide non-negative logits by
      ``repetition_penalty``, so the token becomes less likely either way.

    Iterating over ``output_ids`` with duplicates would apply the frequency
    penalty once per occurrence, i.e. ``occurrences ** 2`` times the
    intended amount.

    Args:
        logits: Indexable, mutable logits for one request.
        output_ids: Tokens generated so far for the request.
        params: Sampling params exposing ``frequency_penalty``,
            ``presence_penalty`` and ``repetition_penalty``.

    Returns:
        The same ``logits`` object, after modification.
    """
    from collections import Counter

    # One pass gives both the distinct tokens and how often each occurred.
    occurrences = Counter(output_ids)

    # Frequency penalty
    if params.frequency_penalty != 0:
        for token_id, count in occurrences.items():
            logits[token_id] -= params.frequency_penalty * count

    # Presence penalty
    if params.presence_penalty != 0:
        for token_id in occurrences:
            logits[token_id] -= params.presence_penalty

    # Repetition penalty
    if params.repetition_penalty != 1.0:
        for token_id in occurrences:
            if logits[token_id] < 0:
                logits[token_id] *= params.repetition_penalty
            else:
                logits[token_id] /= params.repetition_penalty

    return logits
Finish Conditions
Checking Completion
def check_finished(self, batch):
    """Retire every request in the batch that has finished.

    All requests are evaluated with ``is_finished`` first; the finished ones
    are then removed from ``running_batch`` and their resources released.

    Returns:
        The finished requests, in batch order.
    """
    done = [candidate for candidate in batch.reqs if self.is_finished(candidate)]

    for finished in done:
        self.running_batch.remove(finished)
        self.free_request_resources(finished)

    return done
def is_finished(self, req):
    """Decide whether ``req`` is done and record why.

    Conditions are checked in this order:

    1. Length: ``len(output_ids) >= max_new_tokens`` gives ``finish_reason``
       of ``"length"``.
    2. EOS: the last output token equals the tokenizer's EOS id and
       ``ignore_eos`` is false, giving ``"stop"``.
    3. Stop strings: any configured stop string occurs in the decoded
       output, giving ``"stop"`` with ``matched_stop`` set to that string.

    A request that has produced no tokens yet can only finish through the
    length check; the EOS and stop-string checks are skipped so that
    ``output_ids[-1]`` is never evaluated on an empty sequence.

    Returns:
        True if the request finished (``req.finish_reason`` is set),
        otherwise False.
    """
    params = req.sampling_params
    output_ids = req.output_ids

    # Check max tokens
    if len(output_ids) >= params.max_new_tokens:
        req.finish_reason = "length"
        return True

    # Nothing generated yet: there is no last token and no text to match.
    if len(output_ids) == 0:
        return False

    # Check EOS token
    if output_ids[-1] == req.tokenizer.eos_token_id and not params.ignore_eos:
        req.finish_reason = "stop"
        return True

    # Check stop strings
    if params.stop:
        text = req.tokenizer.decode(output_ids)
        for stop_str in params.stop:
            if stop_str in text:
                req.finish_reason = "stop"
                req.matched_stop = stop_str
                return True

    return False
Key Parameters
class SchedulerConfig:
    """Key scheduler parameters and their default values."""

    # --- Batching ---
    max_batch_size: int = 256

    # --- Chunked prefill ---
    # Inputs longer than max_prefill_tokens are chunked; prefill_chunk_size
    # is the chunk size.
    max_prefill_tokens: int = 16_384
    prefill_chunk_size: int = 512

    # --- Memory ---
    # Fraction set aside for the model plus the KV cache.
    mem_fraction_static: float = 0.9

    # --- Radix (prefix) cache ---
    enable_radix_cache: bool = True
    radix_cache_size: int = 1 << 30  # 1GB
Monitoring
def get_stats(self):
    """Return a snapshot of scheduler statistics.

    Ratios whose denominator is zero (no requests served yet, no batches
    run yet, or a zero-sized memory pool) are reported as ``0.0`` instead of
    raising ``ZeroDivisionError``, so the stats can be read on a freshly
    started scheduler.

    Returns:
        Dict with ``waiting_queue_len``, ``running_batch_size``,
        ``cache_hit_rate``, ``avg_batch_size`` and ``memory_usage``.
    """

    def ratio(numerator, denominator):
        # An empty denominator means "no data yet", reported as 0.0.
        return numerator / denominator if denominator else 0.0

    return {
        "waiting_queue_len": len(self.waiting_queue),
        "running_batch_size": len(self.running_batch),
        "cache_hit_rate": ratio(self.cache_hits, self.total_requests),
        "avg_batch_size": ratio(self.total_batch_size, self.num_batches),
        "memory_usage": ratio(self.memory_pool.used, self.memory_pool.total),
    }
Advanced Features
Speculative Decoding
Use a small draft model to speculate future tokens:
def speculative_decode(self, req, num_draft_tokens=5):
    """Propose tokens with the draft model and keep those the target accepts.

    Args:
        req: Request whose ``input_ids`` seed the draft model.
        num_draft_tokens: How many tokens the draft model proposes per call.
            Defaults to 5, the previously hard-coded value.

    Returns:
        Whatever ``verify_tokens`` returns for the draft tokens and the
        target model's logits (the accepted tokens).
    """
    # NOTE(review): the draft is seeded from input_ids only; confirm whether
    # already-generated output tokens should be included.
    draft_tokens = self.draft_model.generate(req.input_ids, k=num_draft_tokens)

    # The target model scores the draft tokens so they can be checked.
    target_logits = self.target_model(draft_tokens)

    return self.verify_tokens(draft_tokens, target_logits)
Multi-Model Scheduling
Schedule across multiple model replicas:
def route_request(self, req):
    """Send ``req`` to the replica with the shortest waiting queue.

    Ties go to the replica that appears first in ``self.replicas``.

    Returns:
        The replica the request was added to.

    Raises:
        ValueError: If there are no replicas to route to.
    """
    replica = min(
        self.replicas,
        key=lambda candidate: len(candidate.waiting_queue),
        default=None,
    )
    if replica is None:
        # Same exception type min() would raise, but with an actionable message.
        raise ValueError("cannot route request: no model replicas available")

    replica.add_request(req)
    return replica
Debugging
Enable Scheduler Logging
# Set the scheduler module's logger to DEBUG level.
import logging
logging.getLogger("sglang.srt.managers.scheduler").setLevel(logging.DEBUG)
Trace Request
def trace_request(self, rid, matched_len=None, cache_size=None, finish_reason=None):
    """Log the lifecycle messages for request ``rid``.

    The prefix-match, cache-allocation and finish messages depend on values
    that are not derivable from ``rid`` alone, so they are taken as optional
    arguments and each message is emitted only when its value is supplied.
    Without them the function would reference undefined names and raise
    ``NameError``.

    Args:
        rid: Request id to include in every message.
        matched_len: Length of the matched prefix, if known.
        cache_size: Amount of cache allocated to the request, if known.
        finish_reason: Why the request finished, if it has finished.
    """
    logger.info(f"Request {rid} added to waiting queue")
    if matched_len is not None:
        logger.info(f"Request {rid} matched prefix of length {matched_len}")
    if cache_size is not None:
        logger.info(f"Request {rid} allocated {cache_size} cache")
    logger.info(f"Request {rid} started execution")
    if finish_reason is not None:
        logger.info(f"Request {rid} finished with reason: {finish_reason}")
Resources
Next Steps