Scheduler

The scheduler is the core component that manages request batching, memory allocation, and execution orchestration in SGLang.

Overview

Location: python/sglang/srt/managers/scheduler.py

Key Responsibilities:
  • Request queueing and prioritization
  • Dynamic batch formation
  • Memory allocation via token-to-KV pool
  • Prefix cache management (RadixAttention)
  • Request lifecycle management

Request States

A request transitions through several states:
┌─────────┐
│ Waiting │  Initial state, in waiting queue
└────┬────┘
     │
     ▼
┌─────────┐
│ Running │  Executing in a batch
└────┬────┘
     │
     ├──→ ┌──────────┐
     │    │ Finished │  Generation complete
     │    └──────────┘
     │
     └──→ ┌──────────┐
          │ Aborted  │  Cancelled by user
          └──────────┘
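
The transitions in the diagram above can be written down as a small state machine. This is a minimal sketch for illustration only: `RequestState`, `ALLOWED_TRANSITIONS`, and `transition` are hypothetical names, not identifiers taken from the SGLang source.

```python
from enum import Enum


class RequestState(Enum):
    WAITING = "waiting"
    RUNNING = "running"
    FINISHED = "finished"
    ABORTED = "aborted"


# Allowed moves, mirroring the diagram (hypothetical table, not SGLang's own).
ALLOWED_TRANSITIONS = {
    RequestState.WAITING: {RequestState.RUNNING},
    RequestState.RUNNING: {RequestState.FINISHED, RequestState.ABORTED},
    RequestState.FINISHED: set(),
    RequestState.ABORTED: set(),
}


def transition(current: RequestState, new: RequestState) -> RequestState:
    """Return the new state, or raise if the diagram does not allow the move."""
    if new not in ALLOWED_TRANSITIONS[current]:
        raise ValueError(f"illegal transition: {current.name} -> {new.name}")
    return new
```

Finished and Aborted are terminal here, so any attempt to move a request out of them raises `ValueError`.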

Scheduling Loop

Main Loop

def event_loop(self):
    """Main scheduling loop."""
    while True:
        # Receive requests from tokenizer manager
        recv_reqs = self.recv_requests()
        
        # Add to waiting queue
        for req in recv_reqs:
            self.waiting_queue.append(req)
        
        # Process one step
        if self.running_batch or self.waiting_queue:
            self.process_batch()
        
        # Send results to detokenizer
        self.send_results()

Batch Processing

def process_batch(self):
    """Process one batch."""
    # 1. Get next batch
    batch = self.get_next_batch()
    
    # 2. Prefill new requests
    if batch.has_prefill:
        self.run_prefill(batch)
    
    # 3. Decode existing requests
    if batch.has_decode:
        self.run_decode(batch)
    
    # 4. Sample tokens
    next_tokens = self.sample_tokens(batch)
    
    # 5. Update requests
    self.update_requests(batch, next_tokens)
    
    # 6. Check finish conditions
    self.check_finished(batch)

Dynamic Batching

Batch Formation

The scheduler dynamically forms batches based on:
  • Available memory
  • Request priorities
  • Prefill chunking constraints
def get_next_batch(self):
    """Form next batch from waiting and running requests."""
    batch = ScheduleBatch()
    
    # Add running requests (already executing)
    for req in self.running_batch:
        batch.add_decode_req(req)
    
    # Add new requests from waiting queue
    while self.waiting_queue:
        req = self.waiting_queue[0]
        
        # Check if we have memory
        if not self.can_allocate_kv_cache(req):
            break
        
        # Check if we should chunk prefill
        if self.should_chunk_prefill(req):
            chunk_size = self.get_prefill_chunk_size()
            batch.add_prefill_req(req, chunk_size)
        else:
            batch.add_prefill_req(req, len(req.input_ids))
        
        self.waiting_queue.pop(0)
        self.running_batch.append(req)
    
    return batch

Continuous Batching

Requests can join or leave batches at any time:
# Iteration 0
batch = [req1, req2, req3]  # Initial batch

# Iteration 1: req4 arrives, req2 finishes
batch = [req1, req3, req4]

# Iteration 2: req5, req6 arrive, req1 finishes
batch = [req3, req4, req5, req6]
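Because finished requests free their slots immediately and new requests fill them on the next iteration, the GPU stays busier than with static batching, where a batch must finish entirely before the next one starts.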
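
The join-and-leave behaviour can be reproduced with a toy simulation. The sketch below assumes each request needs a fixed number of decode steps and ignores prefill and memory limits; `simulate_continuous_batching` and its parameters are hypothetical and exist only for this illustration.

```python
from collections import deque


def simulate_continuous_batching(arrivals, lengths, max_batch_size):
    """Return the batch contents at each iteration.

    arrivals: dict mapping iteration -> list of request ids arriving then.
    lengths: dict mapping request id -> number of decode steps it needs.
    """
    waiting = deque()
    running = {}  # request id -> remaining decode steps
    history = []
    step = 0
    last_arrival = max(arrivals) if arrivals else -1
    while step <= last_arrival or waiting or running:
        waiting.extend(arrivals.get(step, []))
        # Admit waiting requests into free slots, oldest first.
        while waiting and len(running) < max_batch_size:
            rid = waiting.popleft()
            running[rid] = lengths[rid]
        history.append(sorted(running))
        # Run one decode step and drop requests that have finished.
        for rid in list(running):
            running[rid] -= 1
            if running[rid] <= 0:
                del running[rid]
        step += 1
    return history


history = simulate_continuous_batching(
    arrivals={0: ["a", "b"], 1: ["c"]},
    lengths={"a": 1, "b": 3, "c": 1},
    max_batch_size=2,
)
print(history)  # [['a', 'b'], ['b', 'c'], ['b']]
```

Request `c` takes the slot that `a` vacates after the first iteration, without waiting for `b` to finish.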

Chunked Prefill

Why Chunk?

Large prefills can block decode requests, increasing latency:
Without chunking:
[────── Long prefill (2s) ──────][Decode][Decode][Decode]
                                  ↑ High latency

With chunking:
[Prefill chunk][Decode][Prefill chunk][Decode][Prefill chunk][Decode]
                 ↑ Low latency

Implementation

def should_chunk_prefill(self, req):
    """Decide if request should be chunked."""
    return len(req.input_ids) > self.max_prefill_tokens

def get_prefill_chunk_size(self):
    """Determine chunk size based on current load."""
    if len(self.running_batch) > 10:  # Many decode requests
        return 512  # Small chunks for low latency
    else:
        return 2048  # Large chunks for high throughput
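
Once a chunk size is chosen, the prompt is processed in consecutive slices. A minimal sketch of that slicing, assuming the prompt is a plain list of token ids (`split_into_chunks` is a hypothetical helper, not part of the scheduler shown above):

```python
def split_into_chunks(input_ids, chunk_size):
    """Split a prompt into consecutive prefill chunks of at most chunk_size tokens."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [
        input_ids[i:i + chunk_size]
        for i in range(0, len(input_ids), chunk_size)
    ]


chunks = split_into_chunks(list(range(5)), chunk_size=2)
print(chunks)  # [[0, 1], [2, 3], [4]]
```

Only the last chunk may be shorter than `chunk_size`. Decode steps for other requests can be scheduled between consecutive chunks.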

Memory Management

Token-to-KV Pool

The scheduler allocates KV cache via a memory pool:
class MemoryPool:
    def __init__(self, total_size):
        self.total_size = total_size
        self.free_blocks = [Block(0, total_size)]  # Initially all free
    
    def allocate(self, size):
        """Allocate memory block."""
        for block in self.free_blocks:
            if block.size >= size:
                # Split block
                allocated = Block(block.start, size)
                remaining = Block(block.start + size, block.size - size)
                
                self.free_blocks.remove(block)
                if remaining.size > 0:
                    self.free_blocks.append(remaining)
                
                return allocated
        
        return None  # OOM
    
    def free(self, block):
        """Free memory block."""
        self.free_blocks.append(block)
        self.merge_adjacent_blocks()  # Coalesce
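
The pool above relies on a `Block` type and a `merge_adjacent_blocks` step that are not shown. One possible shape for both, as a sketch under assumptions: blocks are described by `start` and `size`, and coalescing is written as a free function that returns a new list instead of a method that mutates the pool.

```python
from dataclasses import dataclass


@dataclass
class Block:
    start: int
    size: int

    @property
    def end(self) -> int:
        return self.start + self.size


def merge_adjacent_blocks(free_blocks):
    """Return a new list in which touching free blocks are coalesced."""
    merged = []
    for block in sorted(free_blocks, key=lambda b: b.start):
        if merged and merged[-1].end == block.start:
            # Extend the previous block instead of keeping two fragments.
            merged[-1] = Block(merged[-1].start, merged[-1].size + block.size)
        else:
            merged.append(Block(block.start, block.size))
    return merged


free = merge_adjacent_blocks([Block(4, 4), Block(0, 4), Block(10, 2)])
print(free)  # [Block(start=0, size=8), Block(start=10, size=2)]
```

Without coalescing, a first-fit allocator like the one above can fail on a large request even though the total free space would be sufficient.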

Eviction Policy

When memory is full, the scheduler can evict cached prefixes:
def evict_cache(self, required_size):
    """Evict cached prefixes to free memory."""
    # LRU eviction
    candidates = sorted(
        self.cached_prefixes,
        key=lambda x: x.last_access_time
    )
    
    freed = 0
    for prefix in candidates:
        if freed >= required_size:
            break
        
        # Evict this prefix
        self.free_kv_cache(prefix)
        freed += prefix.cache_size
        
    return freed >= required_size
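
Sorting all cached prefixes on every eviction costs O(n log n). An alternative is to keep the entries in recency order at all times. The sketch below does that with `collections.OrderedDict`; `LRUPrefixCache` is a hypothetical class for illustration, and it does not account for prefixes still referenced by running requests, which a real cache would have to skip.

```python
from collections import OrderedDict


class LRUPrefixCache:
    """Toy LRU cache keyed by prefix, tracking a size per entry."""

    def __init__(self):
        self._entries = OrderedDict()  # prefix key -> cache size, oldest first

    def touch(self, key, size):
        """Insert or refresh an entry, marking it most recently used."""
        self._entries[key] = size
        self._entries.move_to_end(key)

    def evict(self, required_size):
        """Evict least recently used entries until required_size is freed.

        Returns (freed, evicted_keys).
        """
        freed = 0
        evicted = []
        while self._entries and freed < required_size:
            key, size = self._entries.popitem(last=False)
            freed += size
            evicted.append(key)
        return freed, evicted


cache = LRUPrefixCache()
cache.touch("a", 10)
cache.touch("b", 20)
cache.touch("a", 10)  # "a" becomes the most recently used entry
print(cache.evict(15))  # (20, ['b'])
```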

RadixAttention (Prefix Caching)

Radix Tree Structure

The scheduler maintains a radix tree to track shared prefixes:
class RadixNode:
    def __init__(self):
        self.children = {}  # token -> RadixNode
        self.kv_cache_indices = None  # Where KV cache is stored
        self.ref_count = 0  # Number of requests using this prefix

Prefix Matching

When a new request arrives, the scheduler looks up the longest cached prefix of its tokens:
def match_prefix(self, tokens):
    """Find longest matching prefix in radix tree."""
    node = self.radix_tree_root
    matched_len = 0
    
    for i, token in enumerate(tokens):
        if token in node.children:
            node = node.children[token]
            matched_len = i + 1
        else:
            break
    
    # Reuse KV cache for matched tokens
    if matched_len > 0:
        return node.kv_cache_indices[:matched_len]
    
    return None

Cache Insertion

After computing new KV cache entries, the scheduler records the token sequence in the tree:
def insert_cache(self, tokens, kv_indices):
    """Insert new prefix into radix tree."""
    node = self.radix_tree_root
    
    for i, token in enumerate(tokens):
        if token not in node.children:
            node.children[token] = RadixNode()
        
        node = node.children[token]
        node.kv_cache_indices = kv_indices[:i+1]
        node.ref_count += 1
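
Matching and insertion can be tried end to end with a self-contained toy. This sketch uses one node per token, as the snippets above do; a radix tree proper would store multi-token edges. It assumes one KV slot index per token. `TrieNode` and `PrefixTrie` are hypothetical names, and reference counting is left out.

```python
class TrieNode:
    def __init__(self):
        self.children = {}    # token id -> TrieNode
        self.kv_index = None  # KV slot of the token that leads to this node


class PrefixTrie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, tokens, kv_indices):
        """Record one KV slot per token along the path, keeping existing slots."""
        node = self.root
        for token, kv_index in zip(tokens, kv_indices):
            node = node.children.setdefault(token, TrieNode())
            if node.kv_index is None:
                node.kv_index = kv_index

    def match_prefix(self, tokens):
        """Return the KV slots of the longest cached prefix of tokens."""
        node = self.root
        matched = []
        for token in tokens:
            node = node.children.get(token)
            if node is None:
                break
            matched.append(node.kv_index)
        return matched


trie = PrefixTrie()
trie.insert([1, 2, 3], [100, 101, 102])
print(trie.match_prefix([1, 2, 9]))  # [100, 101]
```

A request whose prompt starts with tokens `1, 2` can reuse slots `100` and `101` and only needs prefill for the remaining tokens.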

Request Prioritization

Priority Levels

Requests can have different priorities:
class Priority:
    HIGH = 2
    NORMAL = 1
    LOW = 0

Scheduling with Priority

def get_next_requests(self):
    """Get next requests sorted by priority."""
    # Sort by priority (highest first), then by arrival time (oldest first)
    self.waiting_queue.sort(
        key=lambda req: (-req.priority, req.arrival_time)
    )
    
    # Schedule high-priority requests first
    batch = []
    for req in self.waiting_queue:
        if self.can_allocate(req):
            batch.append(req)
            if len(batch) >= self.max_batch_size:
                break
    
    return batch
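
Re-sorting the whole waiting queue on every scheduling step is simple but does repeated work. A heap gives the same ordering with O(log n) inserts and pops. The following is a sketch using the standard `heapq` module; `PriorityWaitingQueue` is a hypothetical class, and a monotonically increasing counter stands in for the arrival time.

```python
import heapq
import itertools


class PriorityWaitingQueue:
    """Pops the highest priority first, FIFO among equal priorities."""

    def __init__(self):
        self._heap = []
        self._counter = itertools.count()  # arrival order tie-breaker

    def push(self, req, priority):
        # heapq is a min-heap, so the priority is negated.
        heapq.heappush(self._heap, (-priority, next(self._counter), req))

    def pop(self):
        return heapq.heappop(self._heap)[2]

    def __len__(self):
        return len(self._heap)


queue = PriorityWaitingQueue()
queue.push("low-1", priority=0)
queue.push("high-1", priority=2)
queue.push("high-2", priority=2)
order = [queue.pop() for _ in range(len(queue))]
print(order)  # ['high-1', 'high-2', 'low-1']
```

Because the counter is unique, the request objects themselves are never compared, so they do not need to define an ordering.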

Sampling

Token Sampling

After the model forward pass, the scheduler samples the next token for each request:
def sample_tokens(self, batch, logits):
    """Sample next tokens for batch."""
    next_tokens = []
    
    for i, req in enumerate(batch.reqs):
        # Get logits for this request
        req_logits = logits[i, -1, :]  # Last token
        
        # Apply penalties
        req_logits = self.apply_penalties(
            req_logits,
            req.output_ids,
            req.sampling_params
        )
        
        # Sample
        token = self.sampler.sample(
            req_logits,
            temperature=req.sampling_params.temperature,
            top_p=req.sampling_params.top_p,
            top_k=req.sampling_params.top_k,
        )
        
        next_tokens.append(token)
    
    return next_tokens
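
The `self.sampler.sample` call above is not shown. As a rough illustration of what temperature, top-k, and top-p do to a single row of logits, here is a pure-Python sketch. `sample_token` is a hypothetical function, it treats `temperature == 0` as greedy decoding and `top_k <= 0` as "no top-k limit", and it is not a description of SGLang's actual sampler.

```python
import math
import random


def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Sample a token id from a list of logits."""
    if temperature == 0:
        # Greedy decoding: pick the highest logit.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Softmax over temperature-scaled logits (shifted for numerical stability).
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    probs = sorted(((w / total, i) for i, w in enumerate(weights)), reverse=True)

    # Top-k: keep only the k most likely tokens.
    if top_k > 0:
        probs = probs[:top_k]

    # Top-p: keep the smallest set of tokens whose mass reaches top_p.
    kept, mass = [], 0.0
    for p, i in probs:
        kept.append((p, i))
        mass += p
        if mass >= top_p:
            break

    ids = [i for _, i in kept]
    return rng.choices(ids, weights=[p for p, _ in kept], k=1)[0]


print(sample_token([0.1, 5.0, 0.2], temperature=0))  # 1
```

`random.choices` renormalizes the weights, so the kept probabilities do not need to sum to one after filtering.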

Penalties

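def apply_penalties(self, logits, output_ids, params):
    """Apply frequency, presence, and repetition penalties."""
    # Frequency penalty: scale by how often each token has been generated.
    # Iterate over distinct tokens so each one is penalized exactly once.
    if params.frequency_penalty != 0:
        for token_id in set(output_ids):
            count = output_ids.count(token_id)
            logits[token_id] -= params.frequency_penalty * count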
    
    # Presence penalty
    if params.presence_penalty != 0:
        for token_id in set(output_ids):
            logits[token_id] -= params.presence_penalty
    
    # Repetition penalty
    if params.repetition_penalty != 1.0:
        for token_id in set(output_ids):
            if logits[token_id] < 0:
                logits[token_id] *= params.repetition_penalty
            else:
                logits[token_id] /= params.repetition_penalty
    
    return logits

Finish Conditions

Checking Completion

def check_finished(self, batch):
    """Check which requests have finished."""
    finished_reqs = []
    
    for req in batch.reqs:
        # Check stop conditions
        if self.is_finished(req):
            finished_reqs.append(req)
    
    # Remove finished requests from running batch
    for req in finished_reqs:
        self.running_batch.remove(req)
        self.free_request_resources(req)
    
    return finished_reqs

def is_finished(self, req):
    """Check if request is finished."""
    # Check max tokens
    if len(req.output_ids) >= req.sampling_params.max_new_tokens:
        req.finish_reason = "length"
        return True
    
    # Check EOS token (guard against an empty output list)
    if req.output_ids and req.output_ids[-1] == req.tokenizer.eos_token_id:
        if not req.sampling_params.ignore_eos:
            req.finish_reason = "stop"
            return True
    
    # Check stop strings
    if req.sampling_params.stop:
        text = req.tokenizer.decode(req.output_ids)
        for stop_str in req.sampling_params.stop:
            if stop_str in text:
                req.finish_reason = "stop"
                req.matched_stop = stop_str
                return True
    
    return False

Performance Tuning

Key Parameters

class SchedulerConfig:
    # Batch size
    max_batch_size: int = 256
    
    # Chunked prefill
    max_prefill_tokens: int = 16384  # Chunk if longer
    prefill_chunk_size: int = 512    # Chunk size
    
    # Memory
    mem_fraction_static: float = 0.9  # For model + KV cache
    
    # Radix cache
    enable_radix_cache: bool = True
    radix_cache_size: int = 1024 * 1024 * 1024  # 1GB

Monitoring

def get_stats(self):
    """Get scheduler statistics."""
    return {
        "waiting_queue_len": len(self.waiting_queue),
        "running_batch_size": len(self.running_batch),
        # max(..., 1) avoids division by zero before any request or batch
        "cache_hit_rate": self.cache_hits / max(self.total_requests, 1),
        "avg_batch_size": self.total_batch_size / max(self.num_batches, 1),
        "memory_usage": self.memory_pool.used / self.memory_pool.total,
    }

Advanced Features

Speculative Decoding

Use a small draft model to speculate future tokens:
def speculative_decode(self, req):
    """Generate speculative tokens with draft model."""
    # Generate K tokens with draft model
    draft_tokens = self.draft_model.generate(req.input_ids, k=5)
    
    # Verify with target model in one forward pass over prompt + draft tokens
    target_logits = self.target_model(req.input_ids + draft_tokens)
    
    # Accept/reject speculative tokens
    accepted = self.verify_tokens(draft_tokens, target_logits)
    
    return accepted
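
The `verify_tokens` step is not shown above. For greedy decoding, one common rule is to accept draft tokens up to the first position where the target model disagrees, then take the target's token at that position. The sketch below implements only that greedy rule (sampling-based verification uses a rejection-sampling scheme instead). `verify_greedy` is a hypothetical function that takes token ids, not logits.

```python
def verify_greedy(draft_tokens, target_tokens):
    """Accept the longest draft prefix that matches the target's greedy picks.

    target_tokens[i] is the target model's greedy token at draft position i.
    It has one extra entry for the position after the last draft token.
    """
    accepted = []
    for i, token in enumerate(draft_tokens):
        if token != target_tokens[i]:
            # First mismatch: take the target's token instead and stop.
            accepted.append(target_tokens[i])
            return accepted
        accepted.append(token)
    # All drafts accepted: the target pass also yields one bonus token.
    accepted.append(target_tokens[len(draft_tokens)])
    return accepted


print(verify_greedy([1, 2, 3], [1, 2, 9, 4]))  # [1, 2, 9]
```

Every verification pass yields at least one token, so the output matches what the target model would have produced on its own with greedy decoding.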

Multi-Model Scheduling

Schedule across multiple model replicas:
def route_request(self, req):
    """Route request to least loaded model replica."""
    # Find replica with smallest queue
    replica = min(
        self.replicas,
        key=lambda r: len(r.waiting_queue)
    )
    
    replica.add_request(req)
    return replica

Debugging

Enable Scheduler Logging

import logging
logging.getLogger("sglang.srt.managers.scheduler").setLevel(logging.DEBUG)

Trace Request

def trace_request(self, rid, matched_len, cache_size, finish_reason):
    """Log the lifecycle milestones of one request (collected here for reference)."""
    logger.info(f"Request {rid} added to waiting queue")
    logger.info(f"Request {rid} matched prefix of length {matched_len}")
    logger.info(f"Request {rid} allocated {cache_size} cache")
    logger.info(f"Request {rid} started execution")
    logger.info(f"Request {rid} finished with reason: {finish_reason}")
