Overview
Pipeline Parallelism (PP) partitions a model's layers into sequential stages spread across multiple nodes, which makes ultra-long context sequences practical to serve. Unlike Tensor Parallelism (TP), which requires frequent all-reduce operations, PP communicates only at pipeline stage boundaries, so computation and communication overlap better in multi-node deployments.
Why Pipeline Parallelism?
As LLMs scale toward trillion-parameter architectures and “infinite” context windows, serving infrastructure must evolve:
- Long context bottleneck: Ultra-long sequences create prohibitive Time to First Token (TTFT)
- Multi-node communication: TP's frequent all-reduce traffic becomes a bottleneck once it has to cross node boundaries
- Better overlap: PP communicates only at pipeline stage boundaries
- Chunked prefill: Different chunks can be processed simultaneously across nodes
Detailed analysis: Chunked Pipeline Blog
How It Works
Basic Pipeline Architecture
Node 0 (Layers 0-15) → Node 1 (Layers 16-31) → Node 2 (Layers 32-47) → Node 3 (Layers 48-63)
Each node processes a subset of layers and forwards activations to the next stage.
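As a rough illustration of the layout above, the following sketch computes which layers each stage owns when the layer count divides evenly. The function name stage_layer_ranges is hypothetical and is not part of SGLang.

```python
def stage_layer_ranges(num_layers: int, num_stages: int) -> list[tuple[int, int]]:
    """Return inclusive (first, last) layer indices for each pipeline stage."""
    if num_layers % num_stages != 0:
        raise ValueError("this sketch only handles evenly divisible layer counts")
    per_stage = num_layers // num_stages
    return [(s * per_stage, (s + 1) * per_stage - 1) for s in range(num_stages)]


# 64 layers over 4 nodes, matching the diagram above.
ranges = stage_layer_ranges(64, 4)
for node, (first, last) in enumerate(ranges):
    print(f"Node {node}: layers {first}-{last}")
```

Uneven layer counts are covered under Layer Partition Optimization.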
Dynamic Chunked Prefill
With chunked prefill, long sequences are split into chunks:
Input: [128K tokens]
↓
Chunk 1 (12K) → Node 0 → Node 1 → Node 2 → Node 3
Chunk 2 (10K) → Node 0 → Node 1 → Node 2 → Node 3
Chunk 3 (8K) → Node 0 → Node 1 → Node 2
Different chunks are processed in parallel across pipeline stages, reducing TTFT.
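To see why chunking reduces TTFT, consider an idealized pipeline in which every chunk costs the same time per stage and transfers are free. Both simplifications are assumptions of this sketch; in practice later chunks are slower because the prefix grows. The helper name pipeline_makespan is hypothetical.

```python
def pipeline_makespan(num_stages: int, num_chunks: int, stage_time_full: float) -> float:
    """Time until the last chunk leaves the last stage in an idealized pipeline.

    stage_time_full is the time one stage needs for the whole input.
    Each chunk is assumed to cost stage_time_full / num_chunks per stage.
    """
    per_chunk = stage_time_full / num_chunks
    # The first chunk needs num_stages steps; every further chunk adds one step.
    return (num_chunks + num_stages - 1) * per_chunk


unchunked = pipeline_makespan(num_stages=4, num_chunks=1, stage_time_full=1.0)
chunked = pipeline_makespan(num_stages=4, num_chunks=8, stage_time_full=1.0)
print(f"unchunked: {unchunked:.3f}  8 chunks: {chunked:.3f}")
```

Under these assumptions, splitting the input into 8 chunks on 4 stages cuts the prefill time from 4.0 to 1.375 units, because the stages work on different chunks at the same time.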
Asynchronous Communication
SGLang implements micro-batching with non-blocking P2P communication:
- Decoupled sync/async logic: Send operations return immediately; synchronization is deferred
- Multi-stream execution: Separate streams for forward pass, data transfers, and result processing
- Overlap computation and communication: While one micro-batch computes, the next prepares
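The overlap pattern can be illustrated with plain Python threads. This is only an analogy under stated assumptions: SGLang itself uses GPU streams and P2P communication, and the functions compute, send, and run_overlapped below are hypothetical stand-ins with sleeps in place of real work.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def compute(micro_batch: int) -> str:
    time.sleep(0.05)  # stand-in for a forward pass on this stage
    return f"activations-{micro_batch}"


def send(payload: str, received: list) -> None:
    time.sleep(0.05)  # stand-in for a P2P transfer to the next stage
    received.append(payload)


def run_overlapped(num_micro_batches: int) -> tuple[list, float]:
    received: list = []
    pending = []
    start = time.perf_counter()
    # A single worker plays the role of a dedicated communication stream.
    with ThreadPoolExecutor(max_workers=1) as comm_stream:
        for mb in range(num_micro_batches):
            out = compute(mb)
            # submit() returns immediately, like a non-blocking send.
            pending.append(comm_stream.submit(send, out, received))
        # Deferred synchronization: wait only once all work has been issued.
        for future in pending:
            future.result()
    return received, time.perf_counter() - start


received, elapsed = run_overlapped(4)
print(received, f"{elapsed:.2f}s")
```

While one micro-batch is being transferred, the next one is already computing, so the total time is close to the compute time plus a single transfer rather than the sum of all computes and transfers.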
When to Use Pipeline Parallelism
Use PP when:
- Processing ultra-long contexts (64K+ tokens)
- Scaling across multiple nodes (2-8+ nodes)
- Communication bandwidth is limited between nodes
- Working with large models (100B+ parameters)
Combine with TP when:
- Each node has multiple GPUs
- Model layers are too large for a single GPU
Configuration
Basic Setup
Single Model - Multi-Node
# Node 0 (Master)
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 \
--pp-size 4 \
--nnodes 4 \
--node-rank 0 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--chunked-prefill-size 4096
# Node 1
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 \
--pp-size 4 \
--nnodes 4 \
--node-rank 1 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--chunked-prefill-size 4096
# Repeat for nodes 2 and 3 with appropriate node-rank
This creates a 4-stage pipeline with 8-way TP per stage (32 GPUs total).
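Because the per-node commands differ only in node-rank, they can be generated rather than typed by hand. The sketch below is a hypothetical helper (launch_commands is not an SGLang API) that uses only the flags shown above and checks that the total GPU count, TP size times PP size, divides evenly across nodes.

```python
def launch_commands(model_path: str, tp: int, pp_size: int, nnodes: int,
                    dist_init_addr: str, chunked_prefill_size: int) -> tuple[int, list[str]]:
    """Return (total GPUs, one launch command per node rank)."""
    world_size = tp * pp_size  # TP ranks per stage times number of stages
    if world_size % nnodes != 0:
        raise ValueError("tp * pp_size must be divisible by nnodes")
    commands = []
    for rank in range(nnodes):
        commands.append(" ".join([
            "python -m sglang.launch_server",
            f"--model-path {model_path}",
            f"--tp {tp}",
            f"--pp-size {pp_size}",
            f"--nnodes {nnodes}",
            f"--node-rank {rank}",
            f"--dist-init-addr {dist_init_addr}",
            f"--chunked-prefill-size {chunked_prefill_size}",
        ]))
    return world_size, commands


world_size, commands = launch_commands(
    "deepseek-ai/DeepSeek-V3.1", tp=8, pp_size=4, nnodes=4,
    dist_init_addr="<MASTER_NODE_IP>:29500", chunked_prefill_size=4096,
)
print(f"{world_size} GPUs total")
for command in commands:
    print(command)
```

Replace <MASTER_NODE_IP> with the address of node 0 before running the printed commands on each node.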
With Dynamic Chunking
Dynamic chunking automatically adjusts chunk sizes to minimize pipeline bubbles:
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.65
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 \
--pp-size 4 \
--nnodes 4 \
--node-rank 0 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--chunked-prefill-size 12288 \
--enable-dynamic-chunking
Key parameters:
- --chunked-prefill-size: Initial chunk size (larger when using dynamic chunking)
- SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR: Controls chunk size adaptation (0.6-0.85 recommended)
Dynamic Chunking
Why Dynamic Chunking?
Fixed chunk sizes cause pipeline bubbles because:
- Transformer layers have non-uniform running times
- Longer prefixes take more time for the same chunk size
- Bubbles propagate and accumulate across stages
How It Works
Dynamic chunking predicts the next chunk size so that each chunk takes about as long as the initial one, that is, it satisfies:
Runtime(L + Next Chunk Size) - Runtime(L) = Runtime(Initial Chunk Size)
Where L is the current prefix sequence length.
Algorithm:
- Model cumulative runtime as a quadratic function of sequence length
- Solve for the next chunk size given the current prefix length L
- Align downward to the nearest multiple of max(page-size, 64)
- Apply the smoothing factor for stability
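The steps above can be sketched as follows. This is a minimal illustration, not SGLang's implementation. It assumes Runtime(n) = a*n^2 + b*n with hypothetical fitted coefficients a and b, and it assumes the smoothing factor blends linearly between the initial size (factor 0) and the prediction (factor 1). The sketch aligns as the last step so the returned size is always a valid multiple.

```python
import math


def next_chunk_size(prefix_len: int, initial_chunk: int, a: float, b: float,
                    smooth_factor: float = 0.75, page_size: int = 1) -> int:
    """Predict the next chunk size for the current prefix length (sketch)."""
    # Target: the runtime of the initial chunk under the assumed quadratic model.
    target = a * initial_chunk**2 + b * initial_chunk
    # Runtime(L + x) - Runtime(L) = a*x^2 + (2*a*L + b)*x, so solve
    # a*x^2 + (2*a*L + b)*x - target = 0 for the positive root x.
    lin = 2 * a * prefix_len + b
    predicted = (-lin + math.sqrt(lin * lin + 4 * a * target)) / (2 * a)
    # Assumed linear blend: 0 keeps the initial size, 1 follows the prediction.
    blended = initial_chunk + smooth_factor * (predicted - initial_chunk)
    # Align downward, never returning less than one alignment unit.
    align = max(page_size, 64)
    return max(align, int(blended) // align * align)


# Hypothetical coefficients, chosen only to make the trend visible.
A, B = 1e-9, 1e-5
sizes = [next_chunk_size(L, 12288, A, B) for L in (0, 32768, 65536)]
print(sizes)
```

As the prefix grows, each additional token costs more, so the predicted chunk shrinks to keep the per-chunk runtime roughly constant.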
Tuning Dynamic Chunking
Step 1: Find Optimal Fixed Chunk Size
Test different fixed chunk sizes:
# Test 2K
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 --nnodes 4 --node-rank 0 \
--chunked-prefill-size 2048
# Test 4K
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 --nnodes 4 --node-rank 0 \
--chunked-prefill-size 4096
# Test 8K
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 --nnodes 4 --node-rank 0 \
--chunked-prefill-size 8192
Measure TTFT for your target input token length.
Step 2: Set Initial Dynamic Chunk Size
Use 2-3× the optimal fixed chunk size:
# If optimal fixed size is 4K, use 12K as initial
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.75
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 --nnodes 4 --node-rank 0 \
--chunked-prefill-size 12288 \
--enable-dynamic-chunking
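The rule of thumb from Steps 1 and 2 can be written down as a small helper. The function name initial_dynamic_chunk is hypothetical, and the TTFT values below are placeholders for illustration only, not benchmark results.

```python
def initial_dynamic_chunk(ttft_by_chunk: dict[int, float], multiplier: int = 3) -> int:
    """Pick the fixed chunk size with the lowest TTFT and scale it by 2-3x."""
    if multiplier not in (2, 3):
        raise ValueError("the guideline above suggests a 2-3x multiplier")
    best_fixed = min(ttft_by_chunk, key=ttft_by_chunk.get)
    return best_fixed * multiplier


# Placeholder measurements in seconds (NOT real results): chunk size -> TTFT.
measured = {2048: 3.0, 4096: 2.0, 8192: 2.5}
initial = initial_dynamic_chunk(measured)
print(f"--chunked-prefill-size {initial}")
```

With these placeholder numbers the best fixed size is 4096, so the suggested starting point for dynamic chunking is 12288.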
Step 3: Tune Smoothing Factor
- 1.0: Follows prediction model strictly (may create very small tail chunks)
- 0.6-0.85: Recommended range for best balance
- 0: Disables dynamic adjustment (fixed chunking)
Test different values:
# Conservative (fewer chunks, less aggressive adaptation)
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.6
# Balanced (recommended)
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.75
# Aggressive (more adaptation, may create smaller chunks)
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.85
Layer Partition Optimization
For uneven layer divisions, place larger partitions on higher PP ranks:
# For DeepSeek-V3.1 with PP=4 (61 layers)
export SGLANG_PP_LAYER_PARTITION=15,15,15,16 # Better
# vs
# SGLANG_PP_LAYER_PARTITION=16,15,15,15 # Worse
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 --nnodes 4 --node-rank 0 \
--chunked-prefill-size 12288 \
--enable-dynamic-chunking
Higher ranks spend time idle while they wait for the previous stage's output, so giving them the larger partitions fills that idle time and increases GPU utilization.
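A partition that follows this guideline can be computed with a few lines of Python. The helper layer_partition is hypothetical; it simply gives the remainder layers to the highest ranks and formats the result for SGLANG_PP_LAYER_PARTITION.

```python
def layer_partition(num_layers: int, pp_size: int) -> list[int]:
    """Split layers across PP ranks, placing any extra layers on the highest ranks."""
    base, extra = divmod(num_layers, pp_size)
    # The last `extra` ranks each take one additional layer.
    return [base + (1 if rank >= pp_size - extra else 0) for rank in range(pp_size)]


partition = layer_partition(61, 4)
env_value = ",".join(str(n) for n in partition)
print(f"export SGLANG_PP_LAYER_PARTITION={env_value}")
```

For 61 layers and PP=4 this reproduces the 15,15,15,16 split shown above.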
Case Studies
DeepSeek-V3.1 (128K Context, 4×H20 Nodes)
Fixed Chunking (Baseline):
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--trust-remote-code \
--nnodes 4 --node-rank 0 \
--tp 8 --pp-size 4 \
--port 30000 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--disable-radix-cache \
--mem-fraction-static 0.8 \
--attention-backend fa3 \
--host 0.0.0.0 \
--watchdog-timeout 3600 \
--max-running-requests 128 \
--chunked-prefill-size 4096
Dynamic Chunking (Optimized):
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.65
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--trust-remote-code \
--nnodes 4 --node-rank 0 \
--tp 8 --pp-size 4 \
--port 30000 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--disable-radix-cache \
--mem-fraction-static 0.8 \
--attention-backend fa3 \
--host 0.0.0.0 \
--watchdog-timeout 3600 \
--max-running-requests 128 \
--chunked-prefill-size 12288 \
--enable-dynamic-chunking
Qwen3-235B-A22B-FP8 (128K Context, 4×H20 Nodes)
Fixed Chunking:
python -m sglang.launch_server \
--model-path Qwen/Qwen3-235B-A22B-FP8 \
--trust-remote-code \
--nnodes 4 --node-rank 0 \
--tp 4 --pp-size 8 \
--port 30000 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--disable-radix-cache \
--mem-fraction-static 0.8 \
--attention-backend fa3 \
--host 0.0.0.0 \
--watchdog-timeout 3600 \
--max-running-requests 128 \
--chunked-prefill-size 6144
Dynamic Chunking:
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.8
python -m sglang.launch_server \
--model-path Qwen/Qwen3-235B-A22B-FP8 \
--trust-remote-code \
--nnodes 4 --node-rank 0 \
--tp 4 --pp-size 8 \
--port 30000 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--disable-radix-cache \
--mem-fraction-static 0.8 \
--attention-backend fa3 \
--host 0.0.0.0 \
--watchdog-timeout 3600 \
--max-running-requests 128 \
--chunked-prefill-size 18432 \
--enable-dynamic-chunking
Note: --disable-radix-cache is for reproducible benchmarking only. Remove in production.
Combining with Other Parallelism
PP + TP
Most common combination for large models:
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 \
--pp-size 4 \
--nnodes 4 \
--node-rank 0 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--chunked-prefill-size 4096
PP + TP + EP (MoE Models)
For Mixture-of-Experts models:
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 \
--pp-size 4 \
--ep 8 \
--nnodes 4 \
--node-rank 0 \
--dist-init-addr <MASTER_NODE_IP>:29500 \
--moe-a2a-backend deepep \
--chunked-prefill-size 4096
PP + PD Disaggregation
Combine pipeline parallelism with prefill-decode disaggregation:
# Prefill instance with PP
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--disaggregation-mode prefill \
--tp 8 --pp-size 4 \
--nnodes 4 --node-rank 0 \
--dist-init-addr <PREFILL_MASTER_IP>:29500 \
--chunked-prefill-size 4096
# Decode instance with PP
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--disaggregation-mode decode \
--tp 8 --pp-size 4 \
--nnodes 4 --node-rank 0 \
--dist-init-addr <DECODE_MASTER_IP>:29500
See Prefill-Decode Disaggregation for details.
Configuration Summary
| Parameter | Description | Default | Recommended |
|---|---|---|---|
| --pp-size | Pipeline parallel size | 1 | 2-8 for multi-node |
| --chunked-prefill-size | Initial chunk size | 8192 | 4K-8K (fixed), 12K-18K (dynamic) |
| --enable-dynamic-chunking | Enable dynamic chunk sizing | False | Enable for 64K+ contexts |
| SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR | Chunk adaptation rate | 0.75 | 0.6-0.85 |
| SGLANG_PP_LAYER_PARTITION | Manual layer distribution | Auto | "15,15,15,16" for uneven |
| --mem-fraction-static | KV cache memory | 0.9 | 0.8 for long contexts |
- Start with fixed chunking to establish baseline, then enable dynamic
- Use larger initial chunks (2-3× fixed optimal) with dynamic chunking
- Place larger partitions on higher ranks for uneven layer divisions
- Monitor pipeline bubbles using profiling tools
- Adjust smoothing factor based on your workload characteristics
Troubleshooting
High TTFT
Symptom: Long time to first token with long contexts
Solution: Enable dynamic chunking with appropriate smoothing:
export SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR=0.75
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 \
--chunked-prefill-size 12288 \
--enable-dynamic-chunking
Pipeline Bubbles
Symptom: Low GPU utilization on some pipeline stages
Solution: Adjust layer partition:
export SGLANG_PP_LAYER_PARTITION=15,15,15,16
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 \
--chunked-prefill-size 4096
OOM During Long Context
Symptom: Out of memory with very long sequences
Solution: Reduce chunk size and memory fraction:
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3.1 \
--tp 8 --pp-size 4 \
--chunked-prefill-size 2048 \
--mem-fraction-static 0.75
Best Practices
- Prefer PP over pure TP for multi-node deployments
- Combine with TP within each node for optimal performance
- Enable dynamic chunking for ultra-long contexts (64K+)
- Tune chunk sizes for your specific model and hardware
- Monitor communication overhead between pipeline stages