Supported Backends and Selection Guidance
SGLang’s EP integrates diverse, highly efficient backends for different use cases, allowing fine-grained control over performance trade-offs. Users specify backends via two command-line flags:
- --moe-a2a-backend: Selects the backend for all-to-all communication.
- --moe-runner-backend: Selects the backend for MoE computation.
Backends for All-to-All Communication
| Backend | Description | Use Cases |
|---|---|---|
| none (default) | Disables all-to-all for EP. Uses All-Reduce or All-Gather for token dispatch. | Hybrid EP and TP setups. |
| deepep | DeepEP, a communication library for efficient token shuffling in MoE models. | Large-scale EP deployments. |
| mooncake | An extension of DeepEP for elastic inference, leveraging RDMA for high-performance data transfers. | Elastic EP serving. |
| flashinfer | FlashInfer implementation of all-to-all. | Large-scale EP deployments. |
| ascend_fuseep | Ascend NPU native fused all-to-all communication. | Ascend NPU deployments. |
The DeepEP backend supports two dispatch modes: normal mode (optimized for prefill workloads with high throughput) and low_latency mode (optimized for decode workloads with low latency and CUDA Graph compatibility). Users are advised to set --deepep-mode auto, which switches the dispatch mode automatically at runtime. Setting --deepep-mode normal or --deepep-mode low_latency explicitly is mainly useful for debugging or development.
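As a mental model only, the sketch below shows how auto could map each batch to a concrete mode, assuming (as described above) that prefill batches use normal and decode batches use low_latency. The DeepEPMode enum and the resolve_deepep_mode helper are hypothetical illustrations, not SGLang's internal API.

```python
from enum import Enum


class DeepEPMode(str, Enum):
    """Values accepted by --deepep-mode (hypothetical enum for illustration)."""

    NORMAL = "normal"
    LOW_LATENCY = "low_latency"
    AUTO = "auto"


def resolve_deepep_mode(configured: str, is_decode_batch: bool) -> DeepEPMode:
    """Hypothetical helper: pick the concrete dispatch mode for one batch.

    'auto' sends decode batches to low_latency (CUDA Graph friendly) and
    prefill batches to normal (high throughput); explicit modes are pinned,
    which is what makes them handy for debugging.
    """
    mode = DeepEPMode(configured)  # raises ValueError for unknown values
    if mode is DeepEPMode.AUTO:
        return DeepEPMode.LOW_LATENCY if is_decode_batch else DeepEPMode.NORMAL
    return mode


for configured in ("auto", "normal", "low_latency"):
    for is_decode in (False, True):
        phase = "decode" if is_decode else "prefill"
        print(configured, phase, resolve_deepep_mode(configured, is_decode).value)
```

Pinning a mode removes the per-batch switch, so behavior no longer depends on the batch phase.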
Currently, DeepEP and Mooncake only support cases where ep_size = tp_size. For hybrid EP and TP (i.e., ep_size < tp_size), only the none backend (All-Reduce or All-Gather-based dispatching) is supported.
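A pre-launch sanity check that encodes these two rules might look like the sketch below. check_a2a_backend is a hypothetical function; the backend names come from the table above, and SGLang performs its own validation, which may differ.

```python
A2A_BACKENDS = {"none", "deepep", "mooncake", "flashinfer", "ascend_fuseep"}
# Backends that, per the note above, currently require ep_size == tp_size.
EP_EQUALS_TP_ONLY = {"deepep", "mooncake"}


def check_a2a_backend(backend: str, ep_size: int, tp_size: int) -> None:
    """Hypothetical pre-launch check mirroring the documented constraints."""
    if backend not in A2A_BACKENDS:
        raise ValueError(f"unknown --moe-a2a-backend {backend!r}")
    if backend in EP_EQUALS_TP_ONLY and ep_size != tp_size:
        raise ValueError(
            f"{backend} requires ep_size == tp_size, got {ep_size} vs {tp_size}"
        )
    if ep_size < tp_size and backend != "none":
        # Hybrid EP + TP: only All-Reduce/All-Gather-based dispatch is supported.
        raise ValueError("hybrid EP/TP (ep_size < tp_size) only supports 'none'")


check_a2a_backend("deepep", ep_size=8, tp_size=8)  # pure EP: accepted
check_a2a_backend("none", ep_size=4, tp_size=8)    # hybrid EP + TP: accepted
```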
Backends for MoE Computation
| Backend | Description | Use Cases |
|---|---|---|
| auto (default) | Automatically selects the optimal backend based on model architecture, hardware (e.g., NVIDIA architecture like Ampere, Hopper, Blackwell), quantization scheme (e.g., FP8, FP4), and runtime conditions. | General-purpose deployments; ensures compatibility and performance without user intervention. |
| triton | Triton-based implementation for grouped GEMMs. Creating tuned configurations is highly recommended for higher performance. | Custom kernel development or scenarios requiring high extensibility with torch.compile support. |
| deep_gemm | DeepGEMM backend optimized for MoE matrix multiplications, supporting contiguous layouts for prefill and masked layouts for decode; often JIT-compiled for performance. | Large-scale EP deployments with FP8 block-wise quantization. |
| cutlass | CUTLASS-based backend for efficient GEMMs. | NVIDIA architectures with CUTLASS support. |
| flashinfer_trtllm | FlashInfer integrated with TensorRT-LLM for accelerated MoE computations, supporting FP4 communication operators and high-performance GEMMs. | Blackwell with TRT-LLM. |
| flashinfer_cutlass | FlashInfer combined with CUTLASS for high-performance grouped GEMMs in MoE layers, handling FP4/FP8 quantization efficiently. | Blackwell with FP4/FP8 models. |
| flashinfer_mxfp4 | FlashInfer variant optimized for MXFP4 (microscaling FP4) quantization in MoE runners, focusing on memory-efficient low-precision inference. | Low-precision models with MXFP4. |
| flashinfer_cutedsl | FlashInfer kernels written in CuTe DSL for flexible and efficient MoE kernel generation, integrated with ModelOpt FP4 quantization. | Low-precision models with NVFP4. |
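The two flags are chosen independently and combined on one command line. The sketch below assembles such a command; build_launch_args is a hypothetical helper, the backend lists are copied from the two tables above (and may not be exhaustive), and the model path, --tp 8, and --deepep-mode auto values are illustrative assumptions rather than a tested recipe.

```python
import shlex

# Backend names taken from the two tables above (not necessarily exhaustive).
A2A_BACKENDS = ("none", "deepep", "mooncake", "flashinfer", "ascend_fuseep")
RUNNER_BACKENDS = (
    "auto", "triton", "deep_gemm", "cutlass",
    "flashinfer_trtllm", "flashinfer_cutlass", "flashinfer_mxfp4", "flashinfer_cutedsl",
)


def build_launch_args(model_path, a2a_backend="none", runner_backend="auto", extra=()):
    """Hypothetical helper: assemble an sglang.launch_server command line."""
    if a2a_backend not in A2A_BACKENDS:
        raise ValueError(f"unknown --moe-a2a-backend {a2a_backend!r}")
    if runner_backend not in RUNNER_BACKENDS:
        raise ValueError(f"unknown --moe-runner-backend {runner_backend!r}")
    return [
        "python", "-m", "sglang.launch_server",
        "--model-path", model_path,
        "--moe-a2a-backend", a2a_backend,
        "--moe-runner-backend", runner_backend,
        *extra,
    ]


# DeepEP for all-to-all, DeepGEMM for the expert GEMMs (illustrative values).
args = build_launch_args(
    "deepseek-ai/DeepSeek-V3",
    a2a_backend="deepep",
    runner_backend="deep_gemm",
    extra=("--deepep-mode", "auto", "--tp", "8"),
)
print(shlex.join(args))
```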
Examples
Launch with DeepEP and DeepGEMM for DeepSeek-V3:
Extensible EP Framework
SGLang’s EP framework provides modular abstractions for easy integration of custom kernels, backends, and optimizations. It decouples the MoE forward pass into stages (dispatch → pre-permute → core runner → post-permute → combine), enabling seamless extensions without refactoring core logic.
Framework Overview
The framework centers on FusedMoE as the unified entry point for a single, extensible structure. Key components include:
- Dispatcher: Manages dispatch/combine for backends like DeepEP (implements BaseDispatcher subclasses).
- MoeRunner: Orchestrates grouped-GEMM execution via MoeRunnerCore implementations (e.g., TritonRunnerCore).
- PermuteMethodPool: Auto-registers layout conversions (e.g., pre/post-permute via register_pre_permute and register_post_permute for dynamic modes, or register_fused_func for static, torch.compile-compatible fused operations).
- TopK Router: Backend-agnostic expert selection.
Backends are selected via --moe-a2a-backend and --moe-runner-backend, with quantization integrated through a standardized apply() method. The computation flow (dispatch → pre-permute → core runner → post-permute → combine) keeps each stage modular.
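To make the five stages concrete, here is a single-process toy in plain Python. DispatcherSketch and RunnerCoreSketch only mimic the roles of BaseDispatcher and MoeRunnerCore; their method names and signatures are assumptions, top-1 routing is assumed (expert ids are given rather than computed by a TopK router), and nothing here is SGLang code.

```python
from abc import ABC, abstractmethod


class DispatcherSketch(ABC):
    """Hypothetical stand-in for BaseDispatcher: routes tokens out and back."""

    @abstractmethod
    def dispatch(self, tokens, expert_ids): ...

    @abstractmethod
    def combine(self, expert_outputs, handle): ...


class RunnerCoreSketch(ABC):
    """Hypothetical stand-in for MoeRunnerCore: the expert computation."""

    @abstractmethod
    def run(self, runner_input): ...


class LocalDispatcher(DispatcherSketch):
    """Single-process 'all-to-all': groups (token index, value) pairs by expert."""

    def dispatch(self, tokens, expert_ids):
        grouped = {}
        for idx, (tok, eid) in enumerate(zip(tokens, expert_ids)):
            grouped.setdefault(eid, []).append((idx, tok))
        return grouped, len(tokens)  # handle = original token count

    def combine(self, expert_outputs, handle):
        out = [None] * handle
        for pairs in expert_outputs.values():
            for idx, value in pairs:
                out[idx] = value  # restore the original token order
        return out


class ScaleRunner(RunnerCoreSketch):
    """Toy experts: expert e multiplies each of its tokens by (e + 1)."""

    def run(self, runner_input):
        return {eid: [(idx, tok * (eid + 1)) for idx, tok in pairs]
                for eid, pairs in runner_input.items()}


def sort_by_expert(grouped):
    """Toy pre-permute: present experts in contiguous, sorted order."""
    return dict(sorted(grouped.items()))


def identity(x):
    """Toy post-permute: the runner output is already in combine format."""
    return x


def moe_forward(tokens, expert_ids, dispatcher, runner, pre_permute, post_permute):
    grouped, handle = dispatcher.dispatch(tokens, expert_ids)  # 1. dispatch
    runner_input = pre_permute(grouped)                         # 2. pre-permute
    runner_output = runner.run(runner_input)                    # 3. core runner
    expert_outputs = post_permute(runner_output)                # 4. post-permute
    return dispatcher.combine(expert_outputs, handle)           # 5. combine


result = moe_forward([1.0, 2.0, 3.0], [0, 1, 0], LocalDispatcher(), ScaleRunner(),
                     sort_by_expert, identity)
print(result)
```

Swapping LocalDispatcher or ScaleRunner for another implementation leaves moe_forward untouched, which is the kind of decoupling the stage split is meant to provide.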
Implementing New Backends
To add a new backend:
- For a new all-to-all dispatcher, implement a BaseDispatcher subclass with dispatch and combine methods.
- For a new MoE runner backend, define a MoeRunnerCore subclass for core operations (e.g., grouped GEMMs).
- Define new input/output formats for the dispatcher or model runner (e.g., RunnerInput, RunnerOutput).
- Register permute/unpermute methods to ensure compatibility:
  - Fused Mode (static, torch.compile-compatible): Use register_fused_func for end-to-end operations.
  - Permute Mode (dynamic): Use register_pre_permute and register_post_permute for flexible layouts.
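The two registration paths can be pictured with a small decorator-based registry. PermuteRegistrySketch is hypothetical: it borrows the register_pre_permute, register_post_permute, and register_fused_func names from the list above, but the format keys and call signatures are assumptions, not SGLang's actual PermuteMethodPool API. In this sketch a registered fused function takes precedence over the permute path.

```python
from typing import Callable, Dict, Tuple

FormatPair = Tuple[str, str]


class PermuteRegistrySketch:
    """Hypothetical registry in the spirit of PermuteMethodPool.

    Pre-permute and fused functions are keyed by (dispatch format, runner
    format); post-permute functions by (runner format, dispatch format).
    """

    def __init__(self) -> None:
        self.pre: Dict[FormatPair, Callable] = {}
        self.post: Dict[FormatPair, Callable] = {}
        self.fused: Dict[FormatPair, Callable] = {}

    def _register(self, table: Dict[FormatPair, Callable], key: FormatPair):
        def decorator(fn: Callable) -> Callable:
            table[key] = fn
            return fn
        return decorator

    def register_pre_permute(self, dispatch_fmt: str, runner_fmt: str):
        return self._register(self.pre, (dispatch_fmt, runner_fmt))

    def register_post_permute(self, runner_fmt: str, dispatch_fmt: str):
        return self._register(self.post, (runner_fmt, dispatch_fmt))

    def register_fused_func(self, dispatch_fmt: str, runner_fmt: str):
        return self._register(self.fused, (dispatch_fmt, runner_fmt))

    def run(self, dispatch_fmt, runner_fmt, dispatch_output, core):
        fused = self.fused.get((dispatch_fmt, runner_fmt))
        if fused is not None:
            return fused(dispatch_output)  # fused mode: one end-to-end call
        pre = self.pre[(dispatch_fmt, runner_fmt)]  # KeyError if no path exists
        post = self.post[(runner_fmt, dispatch_fmt)]
        return post(core(pre(dispatch_output)))  # permute mode: three steps


pool = PermuteRegistrySketch()


@pool.register_pre_permute("standard", "toy")
def to_toy(xs):
    return [x * 10 for x in xs]


@pool.register_post_permute("toy", "standard")
def from_toy(xs):
    return [x / 10 for x in xs]


print(pool.run("standard", "toy", [1, 2], core=lambda xs: [x + 1 for x in xs]))
```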
Examples
For an example implementation, see moe_runner/triton.py, which demonstrates Triton-based grouped GEMMs with registered fused and permutation functions.
Computation and Communication Overlap
SGLang’s EP employs advanced overlap techniques to hide communication latency behind computation, maximizing GPU utilization in MoE layers.
Two-Batch Overlap (TBO)
TBO splits requests into micro-batches and interleaves attention computation with dispatch/combine operations. Yield points in the execution graph let one micro-batch pause so the other can run, increasing overall throughput without peak memory spikes. To enable TBO, use --enable-two-batch-overlap, which can unlock up to 2x throughput. For details, see the Large-Scale EP Blog.
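The intuition behind the overlap can be shown with a toy timeline under strong simplifying assumptions: two micro-batches, four equal-length stages that alternate between compute and communication, and one compute stream plus one communication stream. This is not SGLang's scheduler; the stage list and tbo_timeline are hypothetical and only illustrate why offsetting the micro-batches hides communication.

```python
# Alternating stage kinds of one MoE layer (assumed equal durations).
STAGES = [
    ("attention", "compute"),
    ("dispatch", "comm"),
    ("experts", "compute"),
    ("combine", "comm"),
]


def tbo_timeline(stages=STAGES):
    """Toy schedule: micro-batch mb1 runs one stage behind mb0.

    Because compute and communication stages alternate, the one-stage offset
    pairs each communication stage of one micro-batch with a compute stage
    of the other, so both streams stay busy.
    """
    ticks = []
    for t in range(len(stages) + 1):
        slot = {}
        for name, offset in (("mb0", 0), ("mb1", 1)):
            i = t - offset
            if 0 <= i < len(stages):
                stage, kind = stages[i]
                if kind in slot:
                    raise RuntimeError("two stages of the same kind collided")
                slot[kind] = f"{name}:{stage}"
        ticks.append(slot)
    return ticks


timeline = tbo_timeline()
for t, slot in enumerate(timeline):
    print(t, slot)
print(f"overlapped: {len(timeline)} ticks vs sequential: {2 * len(STAGES)} ticks")
```

In this toy the two micro-batches finish in 5 ticks instead of 8; real gains depend on how closely compute and communication durations match.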
Single-Batch Overlap (SBO)
SGLang introduces a dispatcher-hook system for Single-Batch Overlap (SBO), enabling the overlap of operations within a single batch, such as shared-expert computation with communication, while decentralizing logic to enhance modularity. These hooks execute before and after the dispatch and combine operations without modifying core MoE modules. This design simplifies interfaces, reduces coupling, and improves extensibility. For implementation details and an example of overlapping shared experts with DeepEP’s combine operation, refer to PR #13327. Users can set --enable-single-batch-overlap to enable this feature.
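The hook idea can be sketched as follows. HookedDispatcherSketch, its hook-point names, and register_hook are hypothetical and do not reproduce the interface from PR #13327; the point is only that overlap logic attaches around dispatch and combine without touching the MoE layer code.

```python
from typing import Callable, Dict, List

HOOK_POINTS = ("before_dispatch", "after_dispatch", "before_combine", "after_combine")


class HookedDispatcherSketch:
    """Hypothetical dispatcher with hooks around dispatch and combine.

    The MoE layer only calls dispatch()/combine(); overlap logic lives in hooks.
    """

    def __init__(self) -> None:
        self.hooks: Dict[str, List[Callable[[], None]]] = {p: [] for p in HOOK_POINTS}
        self.trace: List[str] = []

    def register_hook(self, point: str, fn: Callable[[], None]) -> None:
        if point not in self.hooks:
            raise ValueError(f"unknown hook point {point!r}")
        self.hooks[point].append(fn)

    def _fire(self, point: str) -> None:
        for fn in self.hooks[point]:
            fn()

    def dispatch(self, x):
        self._fire("before_dispatch")
        self.trace.append("dispatch")  # stands in for the real all-to-all
        self._fire("after_dispatch")
        return x

    def combine(self, x):
        self._fire("before_combine")
        self.trace.append("combine")
        self._fire("after_combine")
        return x


d = HookedDispatcherSketch()
# In a real system the first hook would launch shared-expert work on a side
# stream and the second would wait for it; here they only record the order.
d.register_hook("before_combine", lambda: d.trace.append("shared_experts:start"))
d.register_hook("after_combine", lambda: d.trace.append("shared_experts:wait"))
d.combine(d.dispatch([1.0, 2.0]))
print(d.trace)
```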
Workload Balancer
SGLang integrates the Expert Parallelism Load Balancer (EPLB) from DeepSeek to address routing imbalances in MoE models. By analyzing expert activation statistics, EPLB computes an optimal expert arrangement, strategically placing or replicating experts to minimize GPU utilization variance, reduce idle cycles, and enhance scalability. To enable EPLB, use the flag --enable-eplb. For optimal performance, increase batch sizes to stabilize activation statistics and configure periodic rebalancing (e.g., every 1000 requests) to adapt to evolving workloads. Simulations demonstrate significant improvements in load balancedness (the ratio of mean to max computation time), correlating strongly with throughput gains.
For more details, refer to the EPLB Section in the Large-Scale EP Blog and the EPLB Repository.
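To make the balancedness metric and the idea of replicating hot experts concrete, here is a small greedy sketch. It is a deliberately simple stand-in (give spare slots to the expert with the highest per-replica load, then place replicas largest-first on the least-loaded GPU), not the algorithm implemented in the EPLB repository, and the load numbers are made up.

```python
def replicate(loads, num_slots):
    """Give each spare slot to the expert with the highest load per replica."""
    replicas = [1] * len(loads)
    for _ in range(num_slots - len(loads)):
        hottest = max(range(len(loads)), key=lambda e: loads[e] / replicas[e])
        replicas[hottest] += 1
    return replicas


def place(loads, replicas, num_gpus, slots_per_gpu):
    """Largest-first greedy: each replica goes to the least-loaded GPU with a free slot."""
    items = sorted(
        ((loads[e] / replicas[e], e) for e in range(len(loads)) for _ in range(replicas[e])),
        reverse=True,
    )
    gpu_load = [0.0] * num_gpus
    gpu_experts = [[] for _ in range(num_gpus)]
    for load, expert in items:
        free = [g for g in range(num_gpus) if len(gpu_experts[g]) < slots_per_gpu]
        g = min(free, key=lambda g: gpu_load[g])
        gpu_load[g] += load
        gpu_experts[g].append(expert)
    return gpu_experts, gpu_load


def balancedness(gpu_load):
    """Mean over max per-GPU load (1.0 means perfectly balanced)."""
    return (sum(gpu_load) / len(gpu_load)) / max(gpu_load)


loads = [90, 10, 10, 10, 10, 10, 10, 10]  # made-up per-expert token counts
naive = [sum(loads[2 * g : 2 * g + 2]) for g in range(4)]  # 2 experts/GPU, no replicas
replicas = replicate(loads, num_slots=12)  # 4 GPUs x 3 slots -> 4 redundant slots
layout, gpu_load = place(loads, replicas, num_gpus=4, slots_per_gpu=3)
print("naive:", naive, round(balancedness(naive), 3))
print("greedy:", layout, gpu_load, round(balancedness(gpu_load), 3))
```

Here the hot expert is split across five replicas, which raises balancedness from 0.4 to about 0.87 in this toy setting.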
EP with Speculative Decoding
When using speculative decoding with MTP on MoE architectures, use the --speculative-moe-runner-backend and --speculative-moe-a2a-backend arguments to customize the MoE layer behavior for the draft model. They default to the target model’s settings, but users can set them differently when the target and draft models use different precisions.
For a model such as nvidia/DeepSeek-R1-0528-NVFP4-v2, the target model uses NVFP4 precision while the draft model uses BF16. To apply the flashinfer_trtllm kernel to the target MoE layers while falling back to the triton fused MoE kernel for the draft MoE layers, users can set the arguments as follows:
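The sketch below only shows how the four MoE flags relate in this case. moe_backend_args is a hypothetical helper, and the command omits the other arguments needed to actually enable MTP speculative decoding (for example --speculative-algorithm and its companions), so treat it as an illustration of the flag pairing rather than a complete launch command.

```python
import shlex


def moe_backend_args(target_runner, target_a2a="none", draft_runner=None, draft_a2a=None):
    """Hypothetical helper: MoE backend flags for the target and draft models.

    The speculative flags default to the target model's settings, so they are
    only emitted when the draft model should use something different.
    """
    args = ["--moe-runner-backend", target_runner, "--moe-a2a-backend", target_a2a]
    if draft_runner is not None and draft_runner != target_runner:
        args += ["--speculative-moe-runner-backend", draft_runner]
    if draft_a2a is not None and draft_a2a != target_a2a:
        args += ["--speculative-moe-a2a-backend", draft_a2a]
    return args


cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "nvidia/DeepSeek-R1-0528-NVFP4-v2",
    # NVFP4 target MoE layers on flashinfer_trtllm, BF16 draft MoE layers on triton.
    *moe_backend_args("flashinfer_trtllm", draft_runner="triton"),
]
print(shlex.join(cmd))
```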
Ascend NPU Guidance
Guidance on configuring SGLang for Ascend NPUs:
- --moe-a2a-backend only supports the deepep and ascend_fuseep backends:
  - deepep: The mechanism is consistent with the description above.
  - ascend_fuseep: Offers a large fused operator that integrates all operations between dispatch and combine to boost MoE computation. It is only used for the decode stage in PD Disaggregation Mode.
- The --moe-runner-backend parameter does not need to be configured.
- --deepep-mode:
  - In PD mixed mode, set --deepep-mode auto.
  - In PD Disaggregation Mode, set --deepep-mode normal on the prefill instance and --deepep-mode low_latency on the decode instance.
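The same guidance can be expressed as a lookup per deployment role. ascend_ep_flags is a hypothetical helper that returns only the EP-related flags discussed above; other launch arguments (model path, PD disaggregation settings, and so on) are out of scope.

```python
def ascend_ep_flags(pd_mode, role=None, use_fuseep=False):
    """Hypothetical helper encoding the Ascend guidance above.

    pd_mode is "mixed" or "disaggregated"; role is "prefill" or "decode" for
    disaggregated deployments. --moe-runner-backend is deliberately not set.
    """
    if pd_mode == "mixed":
        if use_fuseep:
            raise ValueError("ascend_fuseep is only for decode in PD disaggregation")
        return ["--moe-a2a-backend", "deepep", "--deepep-mode", "auto"]
    if pd_mode != "disaggregated":
        raise ValueError(f"unknown pd_mode {pd_mode!r}")
    if role == "prefill":
        if use_fuseep:
            raise ValueError("ascend_fuseep is only for decode in PD disaggregation")
        return ["--moe-a2a-backend", "deepep", "--deepep-mode", "normal"]
    if role == "decode":
        if use_fuseep:
            # Assumption: the guidance does not say whether --deepep-mode
            # applies to ascend_fuseep, so it is omitted here.
            return ["--moe-a2a-backend", "ascend_fuseep"]
        return ["--moe-a2a-backend", "deepep", "--deepep-mode", "low_latency"]
    raise ValueError(f"unknown role {role!r}")


for args in (("mixed",), ("disaggregated", "prefill"), ("disaggregated", "decode")):
    print(args, ascend_ep_flags(*args))
```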
DeepEP Ascend Introduction
DeepEP Ascend is the version of the DeepEP communication library adapted for Huawei Ascend NPUs, specifically designed for Expert Parallelism (EP) in Mixture-of-Experts (MoE) models. It supports the Ant-moving Function, which splits the sequence into rounds and transmits them as streaming batches, reducing the buffer size occupied during collective communication in the prefill stage, especially for long sequences. The Ant-moving Function can be enabled for both the dispatch and combine phases via the following environment variables:
- DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS: Enables the ant-moving function in the dispatch stage. Indicates the number of tokens transmitted per round on each rank; default 8192.
- DEEPEP_NORMAL_LONG_SEQ_ROUND: Enables the ant-moving function in the dispatch stage. Indicates the number of rounds transmitted on each rank; default 1.
- DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ: Enables the ant-moving function in the combine stage; default 0 (disabled).
The product DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS * DEEPEP_NORMAL_LONG_SEQ_ROUND represents the input sequence length. When the input sequence length exceeds 8192, enabling the ant-moving function in both the dispatch and combine phases is recommended.
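For example, a helper could derive these variables from a target sequence length. ant_moving_env is hypothetical, and rounding the round count up so the rounds cover the whole sequence is an assumption; values are strings because environment variables are strings.

```python
def ant_moving_env(seq_len, tokens_per_round=8192):
    """Hypothetical helper: environment variables for the ant-moving function.

    The round count is rounded up so that tokens_per_round * rounds covers
    seq_len; the combine-stage switch is turned on above 8192 tokens, following
    the recommendation above.
    """
    if seq_len <= 0 or tokens_per_round <= 0:
        raise ValueError("seq_len and tokens_per_round must be positive")
    rounds = -(-seq_len // tokens_per_round)  # integer ceiling division
    env = {
        "DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS": str(tokens_per_round),
        "DEEPEP_NORMAL_LONG_SEQ_ROUND": str(rounds),
    }
    if seq_len > 8192:
        env["DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ"] = "1"
    return env


print(ant_moving_env(32768))
```

The returned dict can be merged into os.environ (or a subprocess env) before launching the server.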
The environment variable HCCL_BUFFSIZE is used to configure the buffer size (MB) actually allocated. Its calculation formula is as follows:
- hidden_size: The hidden size in the model config.
- topk: The number of selected routing experts.
- TOTAL_SEQ_LEN: The input sequence length.
- PADDING_BUFFSIZE: A value of 20 or greater is recommended.
