System Requirements
Supported TPU Hardware
| TPU Type | HBM Memory | Availability |
|---|---|---|
| TPU v6e | 32 GB | Google Cloud |
| TPU v7 | 96 GB per core | Google Cloud |
Software Requirements
- Python: 3.12 or higher
- JAX: Latest version with TPU support
- Environment: Google Cloud TPU VM or compatible TPU runtime
- Optional: SkyPilot for simplified cloud deployment
Feature Support Matrix
SGLang-JAX provides comprehensive TPU-optimized features for production LLM serving:

| Feature | Support Status | Description |
|---|---|---|
| High-Throughput Continuous Batching | ✅ | Dynamic request batching for maximum TPU utilization |
| Radix Tree KV Cache | ✅ | Memory-efficient prefix sharing between requests |
| FlashAttention Backend | ✅ | TPU-optimized attention kernel for long sequences |
| Tensor Parallelism | ✅ | Distribute models across multiple TPU cores |
| Paged Attention | ✅ | Flexible KV cache management with paging |
| Speculative Decoding (EAGLE/EAGLE3) | ✅ | 20-40% throughput improvement for compatible models |
| Chunked Prefill | ✅ | Mixed prefill-decode batching |
| OpenAI-Compatible API | ✅ | Drop-in replacement for OpenAI API |
| Data Parallel Attention | 🚧 | In development - Attention computation with data parallelism |
| Quantization | 🚧 | In development - Model quantization for reduced memory usage |
| Multi-LoRA | 🚧 | In development - Serve multiple LoRA adapters simultaneously |
Attention Backend Comparison
| Backend | Paged Attention | Spec Decoding | MLA | Sliding Window |
|---|---|---|---|---|
| FlashAttention (fa) | ✅ | ✅ | ❌ | ✅ |
| Native | ❌ | ❌ | ❌ | ❌ |
Optimized Model List
The following models have been tested and optimized for TPU deployment:

| Model Family | Performance Status |
|---|---|
| Qwen 3 | ⭐ Recommended for production |
| Qwen 3 MoE | ⭐ Best performance |
| Qwen 2 | Needs improvement |
| Qwen 2 MoE | Needs improvement |
| Qwen 1.5 | Needs improvement |
| Llama/LLaMA | Needs improvement |
| Grok-2 | Needs improvement |
| Gemma 2 | Verified on TPU |
| Bailing MoE | Needs improvement |
Installation
Method 1: Using PyPI (Recommended)
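As a sketch (the PyPI package name `sglang-jax` is an assumption, and TPU-enabled JAX is installed separately per the standard JAX instructions; check the Installation Guide in the References for exact names):

```shell
# Package name is an assumption; consult the official installation guide
pip install sglang-jax

# TPU-enabled JAX wheel, per the standard JAX install instructions
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```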
Method 2: From Source
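A typical source install might look like the following (the GitHub organization in the URL is an assumption; the repository itself is linked in the References section):

```shell
# Clone the repository and install in editable mode for development
git clone https://github.com/sgl-project/sglang-jax.git
cd sglang-jax
pip install -e .
```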
Method 3: Using Docker
NOTE: Docker support for TPU is currently under development. Please use the PyPI or source installation methods.

Method 4: Cloud TPU with SkyPilot
SkyPilot provides simplified deployment on Google Cloud TPU:

- Install SkyPilot and configure GCP access (see SkyPilot documentation)
- Create a SkyPilot configuration file:
SkyPilot YAML: sglang-jax.sky.yaml
- Launch your TPU cluster:
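With the YAML above saved as `sglang-jax.sky.yaml`, the standard SkyPilot commands bring the cluster up and down (the cluster name here is illustrative):

```shell
# Provision the TPU cluster and run the YAML's setup/run sections
sky launch -c sglang-jax-tpu sglang-jax.sky.yaml

# Inspect the cluster and tear it down when finished
sky status
sky down sglang-jax-tpu
```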
Launching the Serving Engine
Basic Example: Qwen-7B
- `JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache` - Enables JIT compilation caching to accelerate server startup on subsequent runs
- `--tp-size=4` - Tensor parallelism size; match this to your TPU core count (typically 1, 4, or 8)
- `--device=tpu` - Specifies TPU device (this is the default for sglang-jax)
- `--dtype=bfloat16` - Uses bfloat16 precision, which TPUs are optimized for
- `--mem-fraction-static=0.8` - Allocates 80% of TPU HBM for static memory (adjustable from 0.2 to 0.9)
- `--max-prefill-tokens=8192` - Maximum number of tokens processed in the prefill phase
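Putting the flags above together, a launch command might look like this (the `sgl_jax.launch_server` entrypoint and the model path are assumptions; check the repository's quick-start guide for the exact module name):

```shell
# Cache JIT compilations across runs; model path is illustrative
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python -m sgl_jax.launch_server \
  --model-path Qwen/Qwen-7B \
  --device=tpu \
  --dtype=bfloat16 \
  --tp-size=4 \
  --mem-fraction-static=0.8 \
  --max-prefill-tokens=8192
```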
High-Performance Configuration: Qwen3-8B
For production workloads with optimal throughput:

Advanced: Speculative Decoding (EAGLE3)
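A hedged sketch of an EAGLE3 launch for a Qwen3 model. The speculative-decoding flag names mirror upstream SGLang's server arguments and are assumptions for sglang-jax; see the Speculative Decoding documentation for supported combinations:

```shell
# Speculative decoding flags are assumptions modeled on upstream SGLang;
# <eagle3-draft-model> is a placeholder for a compatible draft model
python -m sgl_jax.launch_server \
  --model-path Qwen/Qwen3-8B \
  --device=tpu --dtype=bfloat16 --tp-size=4 \
  --attention-backend=fa \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path <eagle3-draft-model> \
  --speculative-num-steps 3
```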
Speculative decoding can improve throughput by 20-40% for compatible models:

Multi-Node Distributed Serving
For large models requiring multiple TPU VMs:

Benchmarking with Requests
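Upstream SGLang ships `bench_serving` and `bench_one_batch` benchmarking entrypoints; assuming sglang-jax mirrors them under `sgl_jax` (module paths and flags are assumptions), the throughput and latency tests in this section might be run as:

```shell
# Throughput: stress a running server with many concurrent requests
python -m sgl_jax.bench_serving --backend sglang --num-prompts 1000

# Latency: time a single batch through the engine directly
python -m sgl_jax.bench_one_batch \
  --model-path Qwen/Qwen3-8B \
  --batch-size 1 --input-len 128 --output-len 128
```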
Throughput Testing
Basic throughput benchmark:

Latency Testing
Measure single-batch latency:

Comprehensive Benchmark Script
For systematic performance evaluation across different configurations:

Performance Optimization
Memory Optimization
Reduce memory usage:

- Lower `--mem-fraction-static` (from 0.8 → 0.5 → 0.3)
- Decrease `--max-prefill-tokens` (from 16384 → 8192 → 4096)
- Reduce `--max-running-requests`
- Start with conservative memory settings (`--mem-fraction-static=0.5`)
- Gradually increase until you find the optimal balance
- Increase `--page-size` for better memory locality (1 → 16 → 64 → 128)
Throughput Optimization
To maximize tokens per second:

- Use FlashAttention backend: `--attention-backend=fa`
- Enable speculative decoding (EAGLE3) for Qwen3 models (20-40% improvement)
- Increase `--max-running-requests` to 256+
- Set `--mem-fraction-static` to 0.8+ (if memory allows)
- Use larger page sizes (64-128)
- Enable chunked prefill: `--chunked-prefill-size=2048`
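Combining the recommendations above into a single launch invocation (the `sgl_jax.launch_server` entrypoint is an assumption; all flags are taken from this list):

```shell
# Throughput-tuned configuration assembled from the tips above
python -m sgl_jax.launch_server \
  --model-path Qwen/Qwen3-8B \
  --device=tpu --dtype=bfloat16 --tp-size=4 \
  --attention-backend=fa \
  --mem-fraction-static=0.8 \
  --max-running-requests=256 \
  --page-size=64 \
  --chunked-prefill-size=2048
```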
Latency Optimization
To minimize time-to-first-token (TTFT) and inter-token latency:

- Reduce `--page-size` to 1-4
- Lower `--max-running-requests` (16-32) for smaller batches
- Reduce `--chunked-prefill-size`
- Use conservative memory settings to avoid GC pauses
TPU-Specific Optimizations
- **JIT Compilation Cache:** Always set `JAX_COMPILATION_CACHE_DIR` to cache compiled kernels and accelerate server startup.
- **Data Type Optimization:** Use `--dtype=bfloat16` for TPU-native optimization. TPUs are specifically designed for bfloat16 computations.
- **Tensor Parallelism:** Match `--tp-size` to your TPU core configuration (1, 4, or 8) for optimal model distribution.
- **Attention Backend:** Always use `--attention-backend=fa` (FlashAttention) for production workloads.
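For example, persisting the JIT compilation cache across server restarts:

```shell
# Cache compiled TPU kernels so later server starts skip recompilation
export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
mkdir -p "$JAX_COMPILATION_CACHE_DIR"
```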
Troubleshooting
OOM (Out of Memory) Errors
If you encounter out-of-memory errors:

- Reduce `--mem-fraction-static` from 0.8 to 0.5 or lower
- Decrease `--max-prefill-tokens` from 8192 to 4096 or 2048
- Lower `--max-running-requests` to reduce concurrent batch size
- Increase `--page-size` for better memory layout efficiency
Long Compilation Times
If the server takes too long to start:

- Ensure `JAX_COMPILATION_CACHE_DIR` is properly set
- Understand that the first run requires JIT compilation (this is normal)
- Subsequent runs will be significantly faster with cached compilations
- Consider using `--skip-server-warmup` to defer compilation until the first request
Low Throughput
If you're not achieving expected throughput:

- Verify `--tp-size` matches your TPU core configuration
- Check that `--attention-backend=fa` is enabled
- Increase `--max-running-requests` to enable larger batch formation
- Consider enabling speculative decoding for compatible models
- Ensure memory settings allow for sufficient batch sizes
Connection Issues
If clients cannot connect to the server:

- Ensure `--host=0.0.0.0` for external access (not just `127.0.0.1`)
- Verify firewall rules allow traffic on the specified port (default: 30000)
- Check that the server process is running: `curl http://localhost:30000/health`
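A quick connectivity check, using the health endpoint and default port mentioned above:

```shell
# On the server VM: confirm the process answers locally
curl -s http://localhost:30000/health

# From a client machine: substitute the server's external IP or hostname
curl -s http://<server-ip>:30000/health
```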
Advanced Features
Speculative Decoding
SGLang-JAX supports EAGLE and EAGLE3 speculative decoding algorithms for Qwen3 and LLaMA model families. Speculative decoding can improve throughput by 20-40% without affecting output quality. See the Speculative Decoding documentation for detailed configuration and supported model combinations.

Chunked Prefill
Enable mixed prefill-decode batching for better TPU utilization with the `--chunked-prefill-size` flag.

Custom Attention Backends
SGLang-JAX supports a plugin-based attention backend system. You can implement custom attention kernels optimized for specific use cases. See the Attention Backend documentation for implementation details.

Environment Verification
Verify your TPU setup before deploying:

- Installed package versions
- TPU device availability and specifications
- System resources and configuration
- Compatibility of settings
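As a minimal sketch, the standard JAX APIs can confirm TPU visibility from inside the VM (requires a working `jax[tpu]` install):

```shell
# Print the JAX version and the devices JAX can see
python - <<'EOF'
import jax
print("JAX version:", jax.__version__)
print("Devices:", jax.devices())  # should list TpuDevice entries on a TPU VM
EOF
```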
Contributing
We welcome contributions to improve TPU support in SGLang-JAX!

Areas for Contribution
Check the Development Roadmap to see planned features and find opportunities to contribute new functionality. Current contribution areas include:

- Performance optimizations for specific TPU generations
- Support for additional model architectures
- Documentation improvements and examples
- Bug reports and fixes
- Benchmark results and performance analysis
How to Contribute
- Visit the sglang-jax repository
- Read the Contribution Guide
- Join the SGL-JAX Slack community for discussions
- Report issues at sglang-jax/issues
Testing on TPU
For contributors who need TPU access for testing:

- Refer to the TPU Resources Guide for information on accessing TPU hardware
- Use SkyPilot with spot instances for cost-effective testing
- Follow the Benchmark and Profiling Guide for performance validation
References
Documentation
- SGLang-JAX Repository
- SGLang-JAX Installation Guide
- Qwen Models Quick Start
- Benchmark and Profiling Guide
- Speculative Decoding
