Kernel Development
This guide covers developing custom GPU kernels for SGLang, written either in Triton or in CUDA C++.
Overview
SGLang uses highly optimized kernels for:
- Attention: FlashAttention, FlashInfer
- GEMM: Matrix multiplication (via cuBLAS, CUTLASS)
- Elementwise ops: RMSNorm, SiLU, RoPE
- Sampling: Top-k, top-p, softmax
Kernel Location:
- Triton kernels: python/sglang/srt/layers/
- CUDA kernels: the sgl-kernel package (the sgl-kernel/ directory of the sglang repository, released separately on PyPI)
Why Custom Kernels?
Custom kernels provide:
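- Performance: typically 2-10x faster than the equivalent sequence of PyTorch ops, depending on the operation and tensor shapes
- Memory efficiency: Fused operations keep intermediate results in registers or shared memory instead of writing them to global memory, which cuts memory traffic
- Flexibility: Implement operators that PyTorch does not provide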
Triton Kernels
Introduction to Triton
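Triton is a Python-embedded DSL for writing GPU kernels. Each program instance operates on a whole block of data rather than on a single thread, which makes kernels easier to write than in CUDA C++ while still delivering high performance.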
Example: Fused RMSNorm
RMSNorm (Root Mean Square Layer Normalization) is commonly used in modern LLMs.
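For a single row $x$ of hidden size $N$ with learned weight $w$, the two implementations below compute the following. Each output element is the input divided by the root of the mean of squares (plus $\epsilon$), then scaled elementwise by the weight:

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^{2} + \epsilon}}\; w_i, \qquad i = 1, \dots, N
```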
Unfused Implementation (PyTorch)
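```python
import torch


def rmsnorm_pytorch(x, weight, eps=1e-6):
    """RMSNorm using PyTorch ops (computed in float32, cast back to x.dtype)."""
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    x_fp32 = x_fp32 * torch.rsqrt(variance + eps)
    return (x_fp32 * weight.float()).to(x.dtype)
```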
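Problem: every op launches its own kernel and writes its intermediate result to global memory. As a result, the data makes several round trips through DRAM instead of one.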
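To put rough numbers on that bandwidth cost, here is a back-of-the-envelope sketch. It assumes that each of the four steps in rmsnorm_pytorch (square, mean, normalize, scale) runs as its own kernel and reads and writes full-size tensors through DRAM. It ignores the small per-row and weight tensors, any dtype conversions (which would only add traffic), and cache hits, so treat the numbers as illustrative.

```python
# Back-of-the-envelope DRAM traffic for RMSNorm on a [rows, hidden] tensor.
# Assumptions (illustrative): every step is its own kernel that reads/writes
# full-size tensors; per-row scalars, the weight vector, dtype conversions,
# and caches are ignored.

# (step, full-tensor reads, full-tensor writes)
UNFUSED_STEPS = [
    ("x.pow(2)", 1, 1),              # read x, write x^2
    ("mean(-1)", 1, 0),              # read x^2, write one scalar per row (ignored)
    ("x * rsqrt(var + eps)", 1, 1),  # read x, write the normalized tensor
    ("* weight", 1, 1),              # read the normalized tensor, write the output
]
FUSED_STEPS = [("rmsnorm_kernel", 1, 1)]  # read each row once, write it once


def full_tensor_passes(steps):
    """Total number of full-tensor reads plus writes."""
    return sum(reads + writes for _, reads, writes in steps)


def traffic_bytes(steps, rows, hidden, elem_bytes):
    return full_tensor_passes(steps) * rows * hidden * elem_bytes


rows, hidden, elem_bytes = 1024, 4096, 2  # FP16 input
unfused = traffic_bytes(UNFUSED_STEPS, rows, hidden, elem_bytes)
fused = traffic_bytes(FUSED_STEPS, rows, hidden, elem_bytes)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
```

Under these assumptions, the fused kernel moves 3.5x fewer bytes: 16 MiB instead of 56 MiB for a 1024 x 4096 FP16 input. RMSNorm is memory-bound, so this reduction, together with fewer launches, is the main source of the speedup.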
Fused Triton Kernel
```python
import torch
import triton
import triton.language as tl


@triton.jit
def rmsnorm_kernel(
    x_ptr,       # Pointer to input
    weight_ptr,  # Pointer to weight
    output_ptr,  # Pointer to output
    stride,      # Row stride (in elements) for input/output
    N,           # Hidden dimension
    eps,         # Epsilon for numerical stability
    BLOCK_SIZE: tl.constexpr,
):
    # One program instance normalizes one row
    pid = tl.program_id(0)
    row_start = pid * stride

    # Load the input row; mask the tail since BLOCK_SIZE may exceed N
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32)

    # Mean of squares in float32; masked lanes contribute 0
    variance = tl.sum(x * x, axis=0) / N
    rstd = 1.0 / tl.sqrt(variance + eps)

    # Load weight
    weight = tl.load(weight_ptr + offsets, mask=mask, other=1.0).to(tl.float32)

    # Normalize, scale, and store in the output dtype
    output = x * rstd * weight
    tl.store(output_ptr + row_start + offsets,
             output.to(output_ptr.dtype.element_ty), mask=mask)


def rmsnorm_triton(x, weight, eps=1e-6):
    """RMSNorm over the last dimension of a 2D tensor using the Triton kernel."""
    x = x.contiguous()
    batch_size, hidden_dim = x.shape
    output = torch.empty_like(x)

    # One program per row; the whole row must fit in one block
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
    grid = (batch_size,)
    rmsnorm_kernel[grid](
        x, weight, output,
        stride=x.stride(0),
        N=hidden_dim,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
```
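Performance: roughly 3x faster than the PyTorch version. The exact speedup depends on tensor shape and GPU, so measure it on your own workload.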
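One detail in rmsnorm_kernel is easy to get wrong. BLOCK_SIZE is rounded up to a power of 2, so the extra lanes must be masked. The sketch below mirrors one program instance in plain Python (an illustration, not Triton code). It shows why the masked load of x uses other=0.0 and why the mean divides by N rather than by BLOCK_SIZE.

```python
import math


def rmsnorm_row_like_kernel(x_row, weight, eps=1e-6):
    """Plain-Python model of one rmsnorm_kernel program instance (one row)."""
    n = len(x_row)
    block_size = 1 << (n - 1).bit_length()  # next power of 2 >= n (n >= 1)
    offsets = range(block_size)
    mask = [off < n for off in offsets]

    # Masked loads: out-of-range lanes read other=0.0 for x and other=1.0 for weight
    x = [x_row[off] if m else 0.0 for off, m in zip(offsets, mask)]
    w = [weight[off] if m else 1.0 for off, m in zip(offsets, mask)]

    # Padded zeros add nothing to the sum of squares; divide by the true n
    variance = sum(v * v for v in x) / n
    rstd = 1.0 / math.sqrt(variance + eps)
    out = [xi * rstd * wi for xi, wi in zip(x, w)]

    # Masked store: only the first n lanes are written back
    return [o for o, m in zip(out, mask) if m]


def rmsnorm_row_reference(x_row, weight, eps=1e-6):
    """Direct formula, with no padding."""
    rstd = 1.0 / math.sqrt(sum(v * v for v in x_row) / len(x_row) + eps)
    return [v * rstd * w for v, w in zip(x_row, weight)]


x_row = [1.0, -2.0, 3.0, 0.5, 4.0]  # N = 5, so BLOCK_SIZE = 8
weight = [1.0, 0.5, 2.0, 1.0, 1.0]
print(rmsnorm_row_like_kernel(x_row, weight))
```

Dividing by BLOCK_SIZE instead of N would silently shrink the variance whenever N is not a power of 2.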
Triton Best Practices
1. Use Power-of-2 Block Sizes
```python
BLOCK_SIZE = triton.next_power_of_2(N)  # Good: round up and mask the tail
BLOCK_SIZE = N  # Bad: tl.arange(0, BLOCK_SIZE) fails to compile unless N is a power of 2
```
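Rounding up has a cost when N lies just above a power of 2, because the masked lanes still occupy the block. The small sketch below computes what fraction of the block holds real data. The hidden sizes are example values.

```python
def next_power_of_2(n):
    """Smallest power of 2 >= n, for n >= 1 (same result as triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def lane_utilization(n):
    """Fraction of the BLOCK_SIZE lanes that hold real data when one block covers n."""
    return n / next_power_of_2(n)


for hidden in (4096, 5120, 14336):
    print(hidden, next_power_of_2(hidden), f"{lane_utilization(hidden):.1%}")
```

For N = 5120, BLOCK_SIZE becomes 8192 and only 62.5% of the lanes do useful work. When utilization is this low, it can be worth looping over the row in smaller fixed-size chunks instead of loading it in one block.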
2. Coalesce Memory Accesses
```python
# Good: consecutive lanes access consecutive addresses (coalesced)
offsets = tl.arange(0, BLOCK_SIZE)
data = tl.load(ptr + offsets)

# Bad: strided access spreads one load over many memory transactions
offsets = tl.arange(0, BLOCK_SIZE) * stride
data = tl.load(ptr + offsets)
```
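A simplified model makes the difference concrete. It assumes a 32-lane warp in which each lane loads one 4-byte element. It also assumes that global memory is fetched in 32-byte sectors, the transaction granularity on recent NVIDIA GPUs. Caching and Triton's actual thread-to-element layout are ignored.

```python
WARP_SIZE = 32
SECTOR_BYTES = 32  # assumed memory transaction granularity


def sectors_touched(stride_elems, elem_bytes=4, base_addr=0):
    """Distinct 32-byte sectors touched when lane i loads element i * stride_elems."""
    addrs = [base_addr + lane * stride_elems * elem_bytes for lane in range(WARP_SIZE)]
    return len({addr // SECTOR_BYTES for addr in addrs})


def load_efficiency(stride_elems, elem_bytes=4):
    """Useful bytes divided by bytes actually fetched."""
    useful = WARP_SIZE * elem_bytes
    return useful / (sectors_touched(stride_elems, elem_bytes) * SECTOR_BYTES)


for stride in (1, 2, 8, 32):
    print(stride, sectors_touched(stride), f"{load_efficiency(stride):.1%}")
```

A contiguous warp load touches 4 sectors and uses every fetched byte. At a stride of 8 or more elements, each lane lands in its own sector, so only 12.5% of the fetched bytes are useful.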
3. Minimize Synchronization
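```python
# Triton inserts the synchronization its block-level ops need,
# so explicit barriers are rarely required and only add stalls
tl.debug_barrier()  # Use sparingly
```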
4. Optimize Occupancy
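```python
# Tune BLOCK_SIZE (and num_warps) empirically: larger blocks mean fewer program
# instances but more registers per instance, which can lower occupancy.
# benchmark() is a placeholder for your own timing helper, e.g. one built on
# triton.testing.do_bench.
for BLOCK_SIZE in [128, 256, 512, 1024]:
    benchmark(BLOCK_SIZE)
```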
CUDA C++ Kernels
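When you need lower-level control than Triton offers (for example, warp-level primitives or explicit shared-memory management), write CUDA C++ kernels in the sgl-kernel package.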
Example: Fused Add + ReLU
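```cuda
// my_kernels.cu
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Kernel: fused element-wise add and ReLU, one thread per element
__global__ void fused_add_relu_kernel(
    const half* x,
    const half* y,
    half* out,
    int N
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        half sum = __hadd(x[idx], y[idx]);           // FP16 add
        out[idx] = __hmax(sum, __float2half(0.0f));  // ReLU
    }
}

// Host function: launches on the caller's stream so the kernel is ordered
// correctly with the surrounding PyTorch work
void fused_add_relu(
    const half* x,
    const half* y,
    half* out,
    int N,
    cudaStream_t stream
) {
    if (N == 0) return;  // launching zero blocks is an error
    int threads = 256;
    int blocks = (N + threads - 1) / threads;  // ceil(N / threads)
    fused_add_relu_kernel<<<blocks, threads, 0, stream>>>(x, y, out, N);
}
```
PyTorch Binding
```cpp
// bindings.cpp
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

// Defined in my_kernels.cu
void fused_add_relu(
    const half* x,
    const half* y,
    half* out,
    int N,
    cudaStream_t stream
);

torch::Tensor fused_add_relu_torch(
    torch::Tensor x,
    torch::Tensor y
) {
    TORCH_CHECK(x.is_cuda() && y.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kHalf && y.scalar_type() == torch::kHalf,
                "inputs must be float16");
    TORCH_CHECK(x.sizes() == y.sizes(), "inputs must have the same shape");
    x = x.contiguous();
    y = y.contiguous();

    auto out = torch::empty_like(x);
    fused_add_relu(
        reinterpret_cast<const half*>(x.data_ptr()),
        reinterpret_cast<const half*>(y.data_ptr()),
        reinterpret_cast<half*>(out.data_ptr()),
        static_cast<int>(x.numel()),
        at::cuda::getCurrentCUDAStream()
    );
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_add_relu", &fused_add_relu_torch, "Fused add + ReLU (FP16)");
}
```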
Build System
```python
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_kernels",
    ext_modules=[
        CUDAExtension(
            "my_kernels",
            ["my_kernels.cu", "bindings.cpp"],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```
FlashAttention Integration
SGLang uses FlashInfer, among other attention backends, for optimized attention kernels.
Using FlashInfer
```python
from flashinfer import single_prefill_with_kv_cache, batch_decode_with_padded_kv_cache

# Prefill: causal attention over one sequence
output = single_prefill_with_kv_cache(
    q=query,        # [seq_len, num_heads, head_dim]
    k=key_cache,    # [seq_len, num_kv_heads, head_dim]
    v=value_cache,  # [seq_len, num_kv_heads, head_dim]
    causal=True,
)

# Decode: one query token per sequence against a padded KV cache.
# Batched decode APIs differ across FlashInfer releases; recent releases expose
# wrapper classes such as BatchDecodeWithPagedKVCacheWrapper.
output = batch_decode_with_padded_kv_cache(
    q=query,            # [batch_size, num_heads, head_dim]
    k=key_cache,        # [batch_size, max_seq_len, num_kv_heads, head_dim]
    v=value_cache,      # [batch_size, max_seq_len, num_kv_heads, head_dim]
    seq_lens=seq_lens,  # [batch_size]
)
```
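To pin down the shape conventions, here is a plain-Python reference for what causal prefill attention with grouped-query heads computes on one sequence. It is a slow illustration, not FlashInfer's implementation. It assumes the common GQA convention that query head h reads KV head h // (num_heads // num_kv_heads), and the usual 1/sqrt(head_dim) score scaling.

```python
import math


def causal_gqa_attention(q, k, v):
    """Reference causal attention for one sequence, using nested lists.

    q: [seq_len][num_heads][head_dim]; k, v: [seq_len][num_kv_heads][head_dim].
    """
    seq_len, num_heads, head_dim = len(q), len(q[0]), len(q[0][0])
    num_kv_heads = len(k[0])
    group = num_heads // num_kv_heads  # query heads sharing one KV head
    scale = 1.0 / math.sqrt(head_dim)
    out = [[[0.0] * head_dim for _ in range(num_heads)] for _ in range(seq_len)]
    for i in range(seq_len):
        for h in range(num_heads):
            kv_h = h // group
            # Causal mask: position i attends to positions 0..i only
            scores = [scale * sum(a * b for a, b in zip(q[i][h], k[j][kv_h]))
                      for j in range(i + 1)]
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]  # numerically stable softmax
            total = sum(weights)
            for j, w in enumerate(weights):
                for d in range(head_dim):
                    out[i][h][d] += (w / total) * v[j][kv_h][d]
    return out


# 2 tokens, 2 query heads sharing 1 KV head, head_dim = 2
q = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]
k = [[[1.0, 2.0]], [[3.0, 4.0]]]
v = [[[1.0, 2.0]], [[3.0, 6.0]]]
print(causal_gqa_attention(q, k, v))
```

The first token can only attend to itself, so both heads return v[0]. The second token's query is all zeros, so it averages v[0] and v[1] evenly.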
Custom Attention Backend
To add a new attention backend:
- Create the attention class:
```python
# python/sglang/srt/layers/attention/my_attention.py
class MyAttention:
    def __init__(self, num_heads, head_dim, num_kv_heads):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads

    def forward(self, q, k, v, **kwargs):
        # Compute attention here and return the output tensor
        raise NotImplementedError
```
- Register the backend:
```python
# python/sglang/srt/layers/attention/__init__.py
from sglang.srt.layers.attention.my_attention import MyAttention

ATTENTION_BACKENDS = {
    "flashinfer": FlashInferAttention,
    "flashattn": FlashAttention,
    "my_backend": MyAttention,
}
```
- Launch the server with it:
```bash
python -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B \
    --attention-backend my_backend
```
Kernel Optimization Techniques
1. Tiling
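Break the computation into tiles that fit in shared memory. Each block stages one TILE_SIZE x TILE_SIZE tile of A and one of B at a time, so every value loaded from global memory is reused TILE_SIZE times:
```cuda
#define TILE_SIZE 16

// C[M x N] = A[M x K] * B[K x N], all row-major.
// Launch with dim3 block(TILE_SIZE, TILE_SIZE) and
// dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE).
__global__ void matmul_tiled(
    const float* A,
    const float* B,
    float* C,
    int M, int N, int K
) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];

    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    float sum = 0.0f;

    // Loop over tiles along K
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
        // Load tiles into shared memory, zero-padding out-of-range entries
        if (row < M && t * TILE_SIZE + threadIdx.x < K)
            As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
        else
            As[threadIdx.y][threadIdx.x] = 0.0f;
        if (col < N && t * TILE_SIZE + threadIdx.y < K)
            Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
        else
            Bs[threadIdx.y][threadIdx.x] = 0.0f;
        __syncthreads();  // tiles fully loaded before anyone reads them

        // Compute partial sum from shared memory
        for (int k = 0; k < TILE_SIZE; k++)
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();  // done reading before the next tile overwrites them
    }

    if (row < M && col < N)
        C[row * N + col] = sum;
}
```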
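The payoff of tiling is fewer global-memory loads. The plain-Python model below replays matmul_tiled block by block with the same indexing and zero padding. It checks the result against a naive triple loop and counts global loads. It illustrates the indexing only and is not a performance tool. For comparison, a naive kernel loads 2K values per output element.

```python
def matmul_naive(A, B):
    """Reference C = A @ B for nested lists."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)] for i in range(M)]


def matmul_tiled_model(A, B, tile):
    """Replays matmul_tiled on the CPU; returns (C, number of global loads)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    loads = 0
    for by in range((M + tile - 1) // tile):           # blockIdx.y
        for bx in range((N + tile - 1) // tile):       # blockIdx.x
            acc = [[0.0] * tile for _ in range(tile)]  # one `sum` per thread
            for t in range((K + tile - 1) // tile):
                # Stage one tile of A and B in "shared memory", zero-padding edges
                As = [[0.0] * tile for _ in range(tile)]
                Bs = [[0.0] * tile for _ in range(tile)]
                for ty in range(tile):
                    for tx in range(tile):
                        row, col = by * tile + ty, bx * tile + tx
                        if row < M and t * tile + tx < K:
                            As[ty][tx] = A[row][t * tile + tx]
                            loads += 1
                        if col < N and t * tile + ty < K:
                            Bs[ty][tx] = B[t * tile + ty][col]
                            loads += 1
                # Partial sums read only the staged tiles
                for ty in range(tile):
                    for tx in range(tile):
                        for k in range(tile):
                            acc[ty][tx] += As[ty][k] * Bs[k][tx]
            for ty in range(tile):
                for tx in range(tile):
                    row, col = by * tile + ty, bx * tile + tx
                    if row < M and col < N:
                        C[row][col] = acc[ty][tx]
    return C, loads


M, K, N = 4, 4, 4
A = [[float(i + k) for k in range(K)] for i in range(M)]
B = [[float(k - j) for j in range(N)] for k in range(K)]
C, loads = matmul_tiled_model(A, B, tile=2)
assert C == matmul_naive(A, B)
print(f"tiled loads: {loads}, naive loads: {2 * M * N * K}")
```

With a tile width of 2, the model performs half the global loads of the naive kernel (64 vs 128). When the dimensions divide evenly, tiling cuts global loads by a factor of TILE_SIZE in general.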
2. Vectorized Loads
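Load multiple elements per thread with a single 128-bit instruction:
```cuda
// Each thread handles 4 floats with one float4 (16-byte) load/store.
// a, b, c must be 16-byte aligned (cudaMalloc'd buffers are, at offset 0).
// Launch with at least (N + 3) / 4 threads in total.
__global__ void vector_add(
    const float* a,
    const float* b,
    float* c,
    int N
) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    if (idx + 3 < N) {
        float4 a_vec = *reinterpret_cast<const float4*>(&a[idx]);
        float4 b_vec = *reinterpret_cast<const float4*>(&b[idx]);
        float4 c_vec;
        c_vec.x = a_vec.x + b_vec.x;
        c_vec.y = a_vec.y + b_vec.y;
        c_vec.z = a_vec.z + b_vec.z;
        c_vec.w = a_vec.w + b_vec.w;
        *reinterpret_cast<float4*>(&c[idx]) = c_vec;
    } else {
        // Tail: the last N % 4 elements take the scalar path
        for (int i = idx; i < N; i++) {
            c[i] = a[i] + b[i];
        }
    }
}
```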
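The `idx + 3 < N` guard takes the float4 path only for complete groups of four elements. The CPU model below replays the kernel's indexing and counts how often each element is written. The `scalar_tail` flag switches between the guard alone and the guard plus a scalar loop over the leftover elements. The block size of 4 threads is an arbitrary small value chosen for illustration.

```python
def vector_add_model(a, b, threads_per_block=4, scalar_tail=True):
    """Replays the float4 vector_add indexing on the CPU.

    Returns (c, writes) where writes[i] counts how often element i was written.
    scalar_tail=False models the `idx + 3 < N` guard alone.
    """
    n = len(a)
    c = [None] * n
    writes = [0] * n
    total_threads = (n + 3) // 4  # one thread per group of 4 elements
    blocks = (total_threads + threads_per_block - 1) // threads_per_block
    for block in range(blocks):
        for thread in range(threads_per_block):
            idx = (block * threads_per_block + thread) * 4
            if idx + 3 < n:
                lanes = range(idx, idx + 4)  # one float4 load/add/store
            elif scalar_tail:
                lanes = range(idx, n)        # leftover n % 4 elements (empty if idx >= n)
            else:
                lanes = range(0)
            for i in lanes:
                c[i] = a[i] + b[i]
                writes[i] += 1
    return c, writes


a = [float(i) for i in range(10)]
b = [10.0] * 10
c, writes = vector_add_model(a, b)
print(writes)                                        # every element written exactly once
print(vector_add_model(a, b, scalar_tail=False)[1])  # last 10 % 4 = 2 elements never written
```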
3. Warp Shuffle
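Exchange register values between the threads of a warp directly, without going through shared memory:
```cuda
// Warp-level reduction: after 5 shuffle steps (offsets 16, 8, 4, 2, 1),
// lane 0 holds the sum of all 32 lanes. Requires all 32 lanes to be active.
__device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// Launch with blockDim.x a multiple of 32; *output must be zeroed beforehand.
__global__ void reduce_sum(
    const float* input,
    float* output,
    int N
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Out-of-range threads contribute 0 but still take part in the shuffle
    float sum = (idx < N) ? input[idx] : 0.0f;

    // Warp-level reduction
    sum = warp_reduce_sum(sum);

    // First lane of each warp adds the warp's partial sum
    if (threadIdx.x % 32 == 0) {
        atomicAdd(output, sum);
    }
}
```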
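To see why only lane 0 ends up with the full sum, the plain-Python model below replays the shuffle loop. It follows the documented `__shfl_down_sync` behavior: lane i receives the value from lane i + offset, and lanes whose source falls past the end of the warp keep their own value. The block size of 64 in `reduce_sum_model` is an example value.

```python
WARP_SIZE = 32


def shfl_down(vals, offset):
    """Model of __shfl_down_sync(0xffffffff, val, offset): lane i receives lane
    i + offset's value; lanes whose source is past the warp keep their own value."""
    return [vals[i + offset] if i + offset < WARP_SIZE else vals[i] for i in range(WARP_SIZE)]


def warp_reduce_sum(vals):
    """Same loop as the CUDA version; returns every lane's final value."""
    vals = list(vals)
    offset = 16
    while offset > 0:
        shifted = shfl_down(vals, offset)
        vals = [v + s for v, s in zip(vals, shifted)]
        offset //= 2
    return vals  # only lane 0 holds the full sum


def reduce_sum_model(data, block_dim=64):
    """Replays reduce_sum: every warp reduces, and lane 0 does the atomicAdd."""
    n = len(data)
    num_threads = (n + block_dim - 1) // block_dim * block_dim
    output = 0.0  # like *output, must start at zero
    for warp_start in range(0, num_threads, WARP_SIZE):
        lanes = [data[i] if i < n else 0.0 for i in range(warp_start, warp_start + WARP_SIZE)]
        output += warp_reduce_sum(lanes)[0]
    return output


print(reduce_sum_model(list(range(100))))
```

Only one atomic per warp is issued, instead of one per element. This is the main saving over a plain atomicAdd reduction.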
Profiling Kernels
Nsight Compute
```bash
# Profile a specific kernel
ncu --kernel-name "my_kernel" --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed \
    python -m sglang.bench_one_batch --model-path meta-llama/Llama-3.2-1B

# Collect the full metric set into profile.ncu-rep
ncu --set full -o profile python script.py

# Open the report in the Nsight Compute GUI
ncu-ui profile.ncu-rep
```
Key Metrics
- SM Throughput: Streaming Multiprocessor utilization, as a percentage of peak
- Memory Throughput: DRAM bandwidth utilization
- Occupancy: achieved active warps per SM divided by the maximum warps per SM
- Register Usage: registers per thread (higher usage can limit occupancy)
- Shared Memory Usage: bytes per block
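Register and shared-memory usage matter mainly through occupancy. The simplified estimator below shows the interaction. Its per-SM limits are assumptions, roughly those of an A100-class GPU, so look up the real values for your architecture. It also ignores allocation granularity, so it can overestimate occupancy slightly.

```python
# Assumed per-SM limits, roughly those of an A100 (compute capability 8.0).
# Look up the real values for your GPU in NVIDIA's occupancy tables.
MAX_WARPS_PER_SM = 64
MAX_BLOCKS_PER_SM = 32
REGISTERS_PER_SM = 65536
SHARED_MEM_PER_SM = 164 * 1024  # bytes
WARP_SIZE = 32


def theoretical_occupancy(threads_per_block, regs_per_thread, smem_per_block=0):
    """Resident warps / max warps per SM. Simplified: ignores register and
    shared-memory allocation granularity."""
    warps_per_block = (threads_per_block + WARP_SIZE - 1) // WARP_SIZE
    limits = [
        MAX_BLOCKS_PER_SM,
        MAX_WARPS_PER_SM // warps_per_block,
        REGISTERS_PER_SM // (regs_per_thread * warps_per_block * WARP_SIZE),
    ]
    if smem_per_block > 0:
        limits.append(SHARED_MEM_PER_SM // smem_per_block)
    blocks_per_sm = min(limits)
    return blocks_per_sm * warps_per_block / MAX_WARPS_PER_SM


print(theoretical_occupancy(256, 32))                            # register-light kernel
print(theoretical_occupancy(256, 64))                            # twice the registers
print(theoretical_occupancy(128, 32, smem_per_block=48 * 1024))  # shared-memory bound
```

Under these assumptions, going from 32 to 64 registers per thread halves occupancy for a 256-thread block, from 1.0 to 0.5. Lower occupancy is not automatically slower, but it leaves fewer warps available to hide memory latency.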
Testing Kernels
Correctness Test
```python
import torch


def test_correctness():
    x = torch.randn(1024, 4096, dtype=torch.float16, device="cuda")
    weight = torch.randn(4096, dtype=torch.float16, device="cuda")

    # Reference (PyTorch)
    ref = rmsnorm_pytorch(x, weight)
    # Custom kernel (the Triton version defined above)
    out = rmsnorm_triton(x, weight)

    # FP16 outputs need loose tolerances
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-3)
    print("Correctness test passed!")


test_correctness()
```
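The tolerances matter because FP16 keeps only about three significant decimal digits. The sketch below spells out the elementwise criterion that torch.testing.assert_close documents for floating-point inputs, so you can reason about what rtol=1e-2 and atol=1e-3 accept. It is a plain-Python illustration and omits NaN/inf handling; it is not a replacement for the PyTorch helper.

```python
def is_close(actual, expected, rtol=1e-2, atol=1e-3):
    """Per-element check used by torch.testing.assert_close for floats:
    |actual - expected| <= atol + rtol * |expected| (NaN/inf handling omitted)."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def all_close(actual, expected, rtol=1e-2, atol=1e-3):
    if len(actual) != len(expected):
        return False
    return all(is_close(a, e, rtol, atol) for a, e in zip(actual, expected))


# FP16 has a 10-bit mantissa, so values near 1.0 are spaced 2**-10 (~9.8e-4) apart
print(2.0 ** -10)
print(is_close(1.005, 1.0))   # the rtol term dominates for values of order 1
print(is_close(0.0025, 0.0))  # near zero only atol applies
```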
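Performance Benchmark
```python
import time

import torch


def benchmark_kernel(fn, *args, warmup=10, iters=100):
    """Average wall-clock time per call, in seconds."""
    # Warmup (also triggers Triton JIT compilation)
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()

    # Benchmark: launches are asynchronous, so synchronize before stopping the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed / iters


# Compare
x = torch.randn(1024, 4096, dtype=torch.float16, device="cuda")
weight = torch.randn(4096, dtype=torch.float16, device="cuda")
time_pytorch = benchmark_kernel(rmsnorm_pytorch, x, weight)
time_triton = benchmark_kernel(rmsnorm_triton, x, weight)
print(f"PyTorch: {time_pytorch*1000:.3f} ms")
print(f"Triton: {time_triton*1000:.3f} ms")
print(f"Speedup: {time_pytorch/time_triton:.2f}x")
```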
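benchmark_kernel synchronizes once per loop and reports the mean, which amortizes launch overhead. If you also want to see run-to-run variation, a useful complement is a variant that times each call separately and reports the median. The sketch below uses only the standard library. `sync` is a hook you would set to torch.cuda.synchronize for GPU code. Synchronizing after every call adds launch latency to each sample, so this variant overstates the cost of very short kernels.

```python
import statistics
import time


def time_per_call(fn, *args, warmup=10, iters=100, sync=lambda: None):
    """Per-call wall-clock samples in seconds.

    For GPU work pass sync=torch.cuda.synchronize: kernel launches are
    asynchronous, so the clock must stop only after the device finishes.
    """
    for _ in range(warmup):
        fn(*args)
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        sync()
        samples.append(time.perf_counter() - start)
    return samples


def summarize(samples):
    """The median is robust to occasional outliers (e.g. a context switch)."""
    return {
        "median": statistics.median(samples),
        "min": min(samples),
        "mean": statistics.fmean(samples),
    }


stats = summarize(time_per_call(sorted, list(range(10_000)), iters=20))
print({name: f"{value * 1e6:.1f} us" for name, value in stats.items()})
```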
Adding Kernels to sgl-kernel
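See the Contribution Guide for the multi-PR workflow. The work is split across several PRs because sglang installs sgl-kernel from PyPI. A new kernel therefore has to be released before sglang can depend on it.
Step 1: Add Kernel Implementation
```bash
cd sglang/sgl-kernel
# Add your kernel source
vim csrc/my_kernel.cu
```
Step 2: Submit PR
Submit a PR that adds the kernel to sgl-kernel without using it in SGLang yet.
Step 3: Bump Version
Submit a second PR that bumps the sgl-kernel version. This triggers a PyPI release.
Step 4: Use Kernel
Update the sgl-kernel version in sglang's pyproject.toml, then start using the new kernel.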