diff --git a/docker/Dockerfile b/docker/Dockerfile index a48dc5987..c7d771e5c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -39,7 +39,7 @@ RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ # TE does not have wheel on cuda 13 yet, thus need to install from source RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ - pip install nvidia-mathdx==26.6.0 && \ + pip install nvidia-mathdx==25.6.0 && \ pip -v install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.10; \ else \ pip -v install --no-build-isolation "transformer_engine[pytorch]==2.10.0"; \ diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev new file mode 100644 index 000000000..aec02efe8 --- /dev/null +++ b/docker/Dockerfile.dev @@ -0,0 +1,120 @@ +ARG SGLANG_IMAGE_TAG=v0.5.8 +FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang + +# ======================================== Arguments ============================================= + +ARG PATCH_VERSION=v0.5.8 +ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862 + +ARG ENABLE_CUDA_13=0 + +# ======================================== Setup ============================================= + +WORKDIR /root/ + +# ======================================== Apt dependencies ============================================= + +RUN apt update +RUN apt install -y nvtop rsync dnsutils + +# ====================================== Python dependencies ============================================ + +# The compilation is slow, thus should be put at top +# TransformerEngines does not support too high FA2 +RUN MAX_JOBS=64 pip -v install flash-attn==2.7.4.post1 --no-build-isolation + +# The compilation is slow, thus should be put at top +RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ + cd flash-attention/ && git checkout fbf24f67cf7f6442c5cfb2c1057f4bfc57e72d89 && git submodule update --init && cd hopper/ && \ + MAX_JOBS=96 python setup.py install && \ + export python_path=`python -c "import site; 
print(site.getsitepackages()[0])"` && \ + mkdir -p $python_path/flash_attn_3 && \ + cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py && \ + rm -rf flash-attention/ + +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +RUN pip install flash-linear-attention==0.4.1 +RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ + +# TE does not have wheel on cuda 13 yet, thus need to install from source +RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ + pip install nvidia-mathdx==25.6.0 && \ + pip -v install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.10; \ + else \ + pip -v install --no-build-isolation "transformer_engine[pytorch]==2.10.0"; \ + fi + +RUN NVCC_APPEND_FLAGS="--threads 4" \ + pip -v install --disable-pip-version-check --no-cache-dir \ + --no-build-isolation \ + --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 + +RUN git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ + cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \ + pip install -e . 
+ +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall +RUN pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation +RUN pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation + +# This patch from masahi will be included in later Triton releases +RUN if [ "$ENABLE_CUDA_13" = "1" ]; then \ + (cd /root && git clone -b feat/v350_plus_8045 https://github.com/fzyzcjy/triton.git && cd triton && pip install -r python/requirements.txt && pip install --verbose -e .); \ + fi + +COPY requirements.txt /tmp/requirements.txt +RUN pip install -r /tmp/requirements.txt + +# Temporarily install another sgl-kernel version for GB300 without rebuilding the whole image +RUN if [ "$ENABLE_CUDA_13" = "1" ]; then \ + SGL_KERNEL_VERSION=0.3.17.post2 && \ + python3 -m pip install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu130-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps; \ + fi + +# https://github.com/pytorch/pytorch/issues/168167 +RUN pip install nvidia-cudnn-cu12==9.16.0.29 + +# reinstall numpy 1.x for megatron +RUN pip install "numpy<2" + +RUN rm -rf /root/.cache/pip /root/flash-attention + +# ====================================== Patches ============================================ + +COPY docker/patch/${PATCH_VERSION}/megatron.patch /root/Megatron-LM/ +RUN cd Megatron-LM && \ + git update-index --refresh && \ + git apply megatron.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm megatron.patch + +# TODO temporarily skip patching for GB200/GB300 (and require users to bring their own sglang version). should add back later. 
+ARG ENABLE_SGLANG_PATCH=1 +COPY docker/patch/${PATCH_VERSION}/sglang.patch /sgl-workspace/sglang/ +RUN if [ "$ENABLE_SGLANG_PATCH" = "1" ]; then \ + cd /sgl-workspace/sglang && \ + git update-index --refresh && \ + git apply sglang.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm sglang.patch; \ +fi + +# ====================================== Install main package ============================================ + +# TODO may improve +ARG CACHE_BUST=1 +ARG MILES_COMMIT=main +RUN git clone https://github.com/radixark/miles.git /root/miles && \ + cd /root/miles && \ + git checkout ${MILES_COMMIT} && \ + pip install -e . --no-deps + +RUN cd /root/miles/miles/backends/megatron_utils/kernels/int4_qat && \ + pip install . --no-build-isolation \ No newline at end of file diff --git a/docker/patch/v0.5.8/megatron.patch b/docker/patch/v0.5.8/megatron.patch new file mode 100644 index 000000000..8504c1885 --- /dev/null +++ b/docker/patch/v0.5.8/megatron.patch @@ -0,0 +1,772 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- 
a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index acb93ef78..d239db4ab 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the 
gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and 
is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range 
* stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: 
tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + 
dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- 
a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index e21127b87..712793853 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). 
+ +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index a1230568c..1fd52f65a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + 
extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() +- if mtp_in_postprocess: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+ mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a273002b9..4f821cfd5 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index ac839c21f..f18309217 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def 
_batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..58dc4bb70 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..517944f25 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from 
miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. 
++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py 
b/megatron/core/transformer/transformer_config.py +index e2705bd9f..a0aa109b5 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 3ea405770..5a42001b9 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # 
[Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? 
+ nvtx_range_push(suffix="self_attn_bda") +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index b267c8a81..83736acdc 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. 
This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/patch/v0.5.8/sglang.patch b/docker/patch/v0.5.8/sglang.patch new file mode 100644 index 000000000..6c7bb46d4 --- /dev/null +++ b/docker/patch/v0.5.8/sglang.patch @@ -0,0 +1,1439 @@ +diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py +index 0a7a86cab..fc570be37 100644 +--- a/python/sglang/srt/configs/model_config.py ++++ b/python/sglang/srt/configs/model_config.py +@@ -272,10 +272,13 @@ class ModelConfig: + ): + self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + +- if is_draft_model and self.hf_config.architectures[0] in [ +- "Glm4MoeForCausalLM", +- "Glm4MoeLiteForCausalLM", +- ]: ++ if ( ++ is_draft_model ++ and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM" ++ ): ++ self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" ++ ++ if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM": + self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" + + if ( +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index 4738b032f..8dabccd7a 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -315,6 +315,13 @@ 
class DecodePreallocQueue: + ) + return kv_manager + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + def add(self, req: Req, is_retracted: bool = False) -> None: + """Add a request to the pending queue.""" + if self._check_if_req_exceed_kv_capacity(req): +diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py +index a0c80e0d1..2f0105249 100644 +--- a/python/sglang/srt/disaggregation/mooncake/conn.py ++++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -1079,6 +1079,19 @@ class MooncakeKVManager(CommonKVManager): + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" + ) + ++ def close(self): ++ # Batch deregister KV data buffers ++ if self.kv_args.kv_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.kv_data_ptrs) ++ ++ # Batch deregister auxiliary data buffers ++ if self.kv_args.aux_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.aux_data_ptrs) ++ ++ # Batch deregister state/extra pool data buffers ++ if self.kv_args.state_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.state_data_ptrs) ++ + + class MooncakeKVSender(CommonKVSender): + +diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py +index 39a824c3a..dbf80a3e3 100644 +--- a/python/sglang/srt/disaggregation/prefill.py ++++ b/python/sglang/srt/disaggregation/prefill.py +@@ -309,6 +309,13 @@ class PrefillBootstrapQueue: + else: + return bootstrapped_reqs, failed_reqs + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + + class SchedulerDisaggregationPrefillMixin: + """ +diff --git 
a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index b01595526..01a0f0fa3 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1810,7 +1810,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py +index 3cd85df93..c363f809a 100644 +--- a/python/sglang/srt/entrypoints/engine.py ++++ b/python/sglang/srt/entrypoints/engine.py +@@ -50,6 +50,7 @@ from sglang.srt.managers.io_struct import ( + LoadLoRAAdapterFromTensorsReqInput, + LoadLoRAAdapterReqInput, + MultimodalDataInputFormat, ++ PostProcessWeightsReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + RpcReqInput, +@@ -594,6 +595,24 @@ class Engine(EngineBase): + self.tokenizer_manager.update_weights_from_ipc(obj, None) + ) + ++ def post_process_weights( ++ self, ++ restore_weights_before_load: bool = False, ++ post_process_quantization: bool = False, ++ ): ++ """ ++ Optional post-processing for updated weights (e.g., Marlin conversion). ++ Should be called after weight update is finished. 
++ """ ++ obj = PostProcessWeightsReqInput( ++ restore_weights_before_load=restore_weights_before_load, ++ post_process_quantization=post_process_quantization, ++ ) ++ ++ return self.loop.run_until_complete( ++ self.tokenizer_manager.post_process_weights(obj, None) ++ ) ++ + def get_weights_by_name(self, name: str, truncate_size: int = 100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) +diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py +index afac1d03d..435fff418 100644 +--- a/python/sglang/srt/entrypoints/http_server.py ++++ b/python/sglang/srt/entrypoints/http_server.py +@@ -108,6 +108,7 @@ from sglang.srt.managers.io_struct import ( + OpenSessionReqInput, + ParseFunctionCallReq, + PauseGenerationReqInput, ++ PostProcessWeightsReqInput, + ProfileReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, +@@ -956,6 +957,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + ++@app.post("/post_process_weights") ++async def post_process_weights(req: PostProcessWeightsReqInput, request: Request): ++ """ ++ Optional post-processing for updated weights (e.g., Marlin conversion). ++ This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`. 
++ """ ++ success, message = await _global_state.tokenizer_manager.post_process_weights( ++ req, request ++ ) ++ ++ content = {"success": success, "message": message} ++ return ORJSONResponse( ++ content, status_code=200 if success else HTTPStatus.BAD_REQUEST ++ ) ++ + + @app.post("/update_weight_version") + @auth_level(AuthLevel.ADMIN_OPTIONAL) +diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +index 7f18f131d..383776414 100644 +--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py ++++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +@@ -4,6 +4,7 @@ import contextlib + from abc import ABC, abstractmethod + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + ++import os + import torch + from einops import rearrange + +@@ -196,7 +197,7 @@ class Indexer(MultiPlatformOp): + max_position=max_position_embeddings, + base=rope_theta, # type: ignore + rope_scaling=rope_scaling, +- is_neox_style=True, ++ is_neox_style=True if os.environ.get("INDEXER_ROPE_NEOX_STYLE", "1") == "1" else False, + device=get_global_server_args().device, + ) + self.block_size = block_size +@@ -227,8 +228,10 @@ class Indexer(MultiPlatformOp): + + @torch.compile(dynamic=True) if not _is_hip else lambda f: f + def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): +- weights, _ = self.weights_proj(x) +- weights = weights.float() ++ weights, _ = self.weights_proj(x.float()) ++ if weights.shape[1] < 32: ++ assert 32 % weights.shape[1] == 0 ++ weights = weights.repeat_interleave(32 // weights.shape[1], dim=1) + weights = weights * self.n_heads**-0.5 + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + return weights +@@ -945,6 +948,13 @@ class Indexer(MultiPlatformOp): + return_indices, + ) + ++ query, key = self._get_q_k_bf16( ++ q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch ++ ) ++ if query.shape[1] < 32: ++ assert 32 % query.shape[1] == 0 
++ query = query.repeat_interleave(32//query.shape[1], dim=1) ++ + if enable_dual_stream and forward_batch.forward_mode.is_decode_or_idle(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) +diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py +index d7b4cbd1f..b4cca391e 100644 +--- a/python/sglang/srt/layers/layernorm.py ++++ b/python/sglang/srt/layers/layernorm.py +@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + cast_x_before_out_mul: bool = False, +- fp32_residual: bool = False, +- weight_dtype: Optional = None, +- override_orig_dtype: Optional = None, ++ fp32_residual: bool = True, + ) -> None: + super().__init__() + self.cast_x_before_out_mul = cast_x_before_out_mul + self.fp32_residual = fp32_residual +- self.override_orig_dtype = override_orig_dtype +- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) ++ self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( +@@ -140,6 +137,8 @@ class RMSNorm(MultiPlatformOp): + post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + out, _, residual_out = torch_npu.npu_add_rms_norm( + residual, x, self.weight.data, self.variance_epsilon + ) +@@ -153,6 +152,8 @@ class RMSNorm(MultiPlatformOp): + post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + residual_out = torch.empty_like(x) + output = torch.empty_like(x) + fused_add_rms_norm( +@@ -176,6 +177,8 @@ class RMSNorm(MultiPlatformOp): + # NOTE: Remove 
this if aiter kernel supports discontinuous input + x = x.contiguous() + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + out = torch.empty_like(x) + residual_out = torch.empty_like(x) + fused_add_rms_norm( +@@ -194,16 +197,19 @@ class RMSNorm(MultiPlatformOp): + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() +- orig_dtype = self.override_orig_dtype or x.dtype ++ orig_dtype = x.dtype ++ ++ if residual is not None and not self.fp32_residual: ++ x = x + residual ++ if post_residual_addition is not None: ++ x = x + post_residual_addition ++ residual = x.clone() + x = x.to(torch.float32) +- if residual is not None: ++ if residual is not None and self.fp32_residual: + x = x + residual.to(torch.float32) + if post_residual_addition is not None: + x = x + post_residual_addition.to(torch.float32) +- if self.fp32_residual: +- residual = x.clone() +- else: +- residual = x.to(orig_dtype) ++ residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: +@@ -244,6 +250,8 @@ class RMSNorm(MultiPlatformOp): + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if _is_cpu_amx_available: + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( + x, residual, self.weight.data, self.variance_epsilon + ) +@@ -263,6 +271,8 @@ class RMSNorm(MultiPlatformOp): + if self.variance_size_override is not None: + return self.forward_native(x, residual, post_residual_addition) + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) +@@ -284,6 +294,8 @@ class RMSNorm(MultiPlatformOp): + ) + + if 
get_tensor_model_parallel_world_size() > 1: ++ if post_residual_addition is not None: ++ x = x + post_residual_addition + fused_result = flashinfer_allreduce_residual_rmsnorm( + input_tensor=x, + residual=residual, +@@ -389,6 +401,8 @@ class GemmaRMSNorm(MultiPlatformOp): + post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + gemma_fused_add_rmsnorm( + x, residual, self.weight.data, self.variance_epsilon + ) +@@ -405,6 +419,8 @@ class GemmaRMSNorm(MultiPlatformOp): + orig_dtype = x.dtype + if residual is not None: + x = x + residual ++ if post_residual_addition is not None: ++ x = x + post_residual_addition + residual = x + + x = x.float() +@@ -430,6 +446,8 @@ class GemmaRMSNorm(MultiPlatformOp): + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if _is_cpu_amx_available: + if residual is not None: ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu( + x, residual, self.weight.data, self.variance_epsilon + ) +@@ -447,6 +465,8 @@ class GemmaRMSNorm(MultiPlatformOp): + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + x = x + residual ++ if post_residual_addition is not None: ++ x = x + post_residual_addition + residual = x + + x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon) +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index fa7431048..cd33ea735 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): + None, # bias + True, # is_vnni + ) +- elif get_global_server_args().rl_on_policy_target is not None: +- # Due to tie-weight, we may not be able to 
change lm_head's weight dtype +- logits = torch.matmul( +- hidden_states.bfloat16(), lm_head.weight.T.bfloat16() +- ) + else: + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +index a1885fade..14d692365 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@@ -14,6 +14,7 @@ import torch.nn.functional as F + import triton.language as tl + + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +@@ -573,7 +574,10 @@ def fused_experts_impl( + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. +- if tokens_in_chunk <= 32: ++ if ( ++ not get_global_server_args().enable_deterministic_inference ++ and tokens_in_chunk <= 32 ++ ): + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +index b5f90f255..feffd3cae 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +@@ -667,7 +667,7 @@ class FusedMoE(torch.nn.Module): + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ] +- ) ++ ) and "zero" not in weight_name + else loaded_weight + ) + +diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py +index 00bd68755..5a3ca8a67 100644 +--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py ++++ 
b/python/sglang/srt/layers/moe/routed_experts_capturer.py +@@ -1,5 +1,6 @@ + import logging + from abc import ABC ++from contextlib import contextmanager + from typing import Optional + + import numpy as np +@@ -8,13 +9,18 @@ import torch + + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.layers.dp_attention import ( ++ attn_tp_all_gather_into_tensor, + get_attention_dp_rank, ++ get_attention_tp_size, + get_dp_local_info, + is_dp_attention_enabled, + ) + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.server_args import get_global_server_args ++from sglang.srt.layers.moe import ( ++ get_moe_a2a_backend, ++) + + logger = logging.getLogger(__name__) + +@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + device=device, + ) + ++ if get_moe_a2a_backend().is_deepep(): ++ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 ++ self.gather_buffer = torch.empty( ++ ( ++ self.device_cache.buffer.shape[0] * attn_tp_size, ++ self.device_cache.buffer.shape[2], ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ + def _sync_fwd_experts_buffer_DtoH( + self, + forward_batch: ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: int, + ): +- if is_dp_attention_enabled(): ++ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer ++ # contains data from all DP ranks. We should not slice by DP rank in this case. 
++ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + # handle with cuda graph padding + if can_run_graph: +@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + ].cpu() + + def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ if get_moe_a2a_backend().is_deepep(): ++ local_topk_ids = topk_ids ++ topk_ids = self.gather_buffer[ ++ : local_topk_ids.size(0) * get_attention_tp_size() ++ ] ++ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) + + def get_routed_experts( +diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +index 8253f6ebe..c32b865e3 100644 +--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py ++++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +@@ -469,7 +469,7 @@ class CompressedTensorsConfig(QuantizationConfig): + ) + is_static = not weight_quant.dynamic + +- return is_channel_group and input_quant_none and is_symmetric and is_static ++ return is_channel_group and input_quant_none and is_static + + def _is_dynamic_token_w4( + self, weight_quant: BaseModel, input_quant: BaseModel +diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +index 4bfb7abf3..e8cb15e3b 100644 +--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py ++++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +@@ -35,7 +35,10 @@ from sglang.srt.layers.quantization.fp8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, + ) + from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack +-from 
sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales ++from sglang.srt.layers.quantization.marlin_utils import ( ++ marlin_moe_permute_scales, ++ moe_awq_to_marlin_zero_points ++) + from sglang.srt.layers.quantization.utils import ( + all_close_1d, + per_tensor_dequantize, +@@ -1011,7 +1014,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder +- assert config.symmetric, "Only symmetric quantization is supported for MoE" ++ self.sym = config.symmetric + + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value +@@ -1066,7 +1069,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales +- load_full_w2 = self.actorder and self.group_size != -1 ++ load_full_w2 = (self.actorder != 'static') and self.group_size != -1 + + if load_full_w2: + w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size +@@ -1114,6 +1117,32 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + ++ # add zero param ++ if not self.sym: ++ w13_qzeros = torch.nn.Parameter( ++ torch.empty( ++ num_experts, ++ num_groups_w13, ++ 2 * intermediate_size_per_partition // self.packed_factor, ++ dtype=torch.int32, ++ ), ++ requires_grad=False, ++ ) ++ layer.register_parameter("w13_weight_zero_point", w13_qzeros) ++ set_weight_attrs(w13_qzeros, extra_weight_attrs) ++ ++ w2_qzeros = torch.nn.Parameter( ++ torch.empty( ++ num_experts, ++ num_groups_w2, ++ hidden_size // self.packed_factor, ++ dtype=torch.int32, ++ ), ++ requires_grad=False, ++ ) ++ layer.register_parameter("w2_weight_zero_point", w2_qzeros) ++ set_weight_attrs(w2_qzeros, extra_weight_attrs) ++ + w13_g_idx = 
torch.nn.Parameter( + torch.empty( + num_experts, +@@ -1167,13 +1196,22 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + # Force record: these are the target GPTQ shapes for rollback. + layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) +- layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) ++ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) ++ if not self.sym: ++ layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape + +- # Also record the shapes of the scales. ++ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) + layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) +- layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) ++ if not self.sym: ++ layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ++ # Skip if the layer is already converted to Marlin format to prevent double-packing. ++ if getattr(layer, "is_marlin_converted", False): ++ return ++ ++ if not hasattr(layer, "_original_shapes"): ++ layer._original_shapes = {} + + # Skip if the layer is already converted to Marlin format to prevent double-packing. 
+ if getattr(layer, "is_marlin_converted", False): +@@ -1276,11 +1314,28 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + ++ # Repack zero ++ if not self.sym: ++ marlin_w13_zp = moe_awq_to_marlin_zero_points( ++ layer.w13_weight_zero_point, ++ size_k=layer.w13_weight_zero_point.shape[1], ++ size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor, ++ num_bits=self.num_bits, ++ ) ++ replace_tensor("w13_weight_zero_point", marlin_w13_zp) ++ ++ marlin_w2_zp = moe_awq_to_marlin_zero_points( ++ layer.w2_weight_zero_point, ++ size_k=layer.w2_weight_zero_point.shape[1], ++ size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor, ++ num_bits=self.num_bits, ++ ) ++ replace_tensor("w2_weight_zero_point", marlin_w2_zp) ++ + layer.is_marlin_converted = True + + def restore_weights_before_loading(self, layer: torch.nn.Module): + """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" +- + if not hasattr(layer, "_original_shapes"): + return + +@@ -1341,6 +1396,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, ++ w1_zeros=layer.w13_weight_zero_point if not self.sym else None, ++ w2_zeros=layer.w2_weight_zero_point if not self.sym else None, + num_bits=self.num_bits, + is_k_full=self.is_k_full, + routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, +diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py +index 8fbdf3160..582f3d675 100644 +--- a/python/sglang/srt/layers/rotary_embedding.py ++++ b/python/sglang/srt/layers/rotary_embedding.py +@@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp): + + if get_global_server_args().rl_on_policy_target is not None: + self._forward_method = self.forward_native +- 
self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( +- self._apply_rotary_emb_wrapped +- ) ++ + self.position_cos, self.position_sin = None, None + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: +@@ -1578,6 +1576,9 @@ class MRotaryEmbedding(RotaryEmbedding): + key: torch.Tensor, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: ++ assert ( ++ fused_set_kv_buffer_arg is None ++ ), "fused_set_kv_buffer_arg is not supported for npu implementation" + # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 + assert ( + fused_set_kv_buffer_arg is None +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 55bef5652..35ad68b1c 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -108,16 +108,11 @@ class Sampler(nn.Module): + if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: + probs_without_temp_scaling = torch.softmax(logits, dim=-1) + +- if get_global_server_args().rl_on_policy_target is not None: +- logits_div_temperature = ( +- logits.bfloat16().div(sampling_info.temperatures).bfloat16() +- ) +- logprobs_via_logsoftmax_kernel = torch.log_softmax( +- logits_div_temperature, dim=-1 +- ) +- + # Post process logits + logits.div_(sampling_info.temperatures) ++ if get_global_server_args().rl_on_policy_target is not None: ++ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) ++ + # For ascend backend, softmax is not needed before sampling + if not get_global_server_args().sampling_backend == "ascend" or ( + return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index fad02e0a0..72ceea737 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -1306,6 +1306,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): + success: bool + 
message: str + ++@dataclass ++class PostProcessWeightsReqInput(BaseReq): ++ # Whether to restore weights before loading new weights ++ restore_weights_before_load: bool = False ++ # Whether to enable quantization post-processing ++ post_process_quantization: bool = False ++ ++ ++@dataclass ++class PostProcessWeightsReqOutput(BaseReq): ++ success: bool ++ message: str ++ + + @dataclass + class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index f3ca299bd..c01ce4975 100644 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -2332,7 +2332,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " +- f"#req={(len(self.reqs))})" ++ f"#req={(len(self.reqs))}), " ++ f"#out_cache_loc={self.out_cache_loc})" + ) + + +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 99f14e3ef..fcb03a8b5 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -103,6 +103,7 @@ from sglang.srt.managers.io_struct import ( + OpenSessionReqInput, + OpenSessionReqOutput, + PauseGenerationReqInput, ++ PostProcessWeightsReqInput, + ProfileReq, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, +@@ -1040,6 +1041,7 @@ class Scheduler( + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), ++ (PostProcessWeightsReqInput, self.post_process_weights), + (GetWeightsByNameReqInput, self.get_weights_by_name), + (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), + (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), +diff --git 
a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index c4728b714..43890f734 100644 +--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode + from sglang.srt.environ import envs + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer ++ + from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, +@@ -1069,7 +1070,7 @@ class SchedulerOutputProcessorMixin: + req.log_time_stats() + + # Send to detokenizer +- if reqs or is_idle_batch: ++ if rids or is_idle_batch: + if self.model_config.is_multimodal_gen: + return + self.send_to_detokenizer.send_output( +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index 293a84350..c3a618bcc 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -1,6 +1,7 @@ + from __future__ import annotations + + import logging ++import os + import traceback + from typing import TYPE_CHECKING, Tuple + +@@ -12,6 +13,9 @@ from sglang.srt.constants import ( + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, + ) ++from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group ++from sglang.srt.layers.dp_attention import get_attention_tp_group + from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, +@@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import ( + GetWeightsByNameReqOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, ++ 
PostProcessWeightsReqInput, ++ PostProcessWeightsReqOutput, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, +@@ -114,6 +120,11 @@ class SchedulerUpdateWeightsMixin: + torch.distributed.barrier(group=self.tp_cpu_group) + return UpdateWeightsFromIPCReqOutput(success, message) + ++ def post_process_weights(self, recv_req: PostProcessWeightsReqInput): ++ """Optional post-processing for updated weights (e.g., Marlin conversion).""" ++ success, message = self.tp_worker.post_process_weights(recv_req) ++ return PostProcessWeightsReqOutput(success, message) ++ + def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): + parameter = self.tp_worker.get_weights_by_name(recv_req) + return GetWeightsByNameReqOutput(parameter) +@@ -137,6 +148,13 @@ class SchedulerUpdateWeightsMixin: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.release_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.release_memory_occupation() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model +@@ -177,6 +195,13 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.resume_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() ++ + return 
ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py +index e25729e71..e2075bb85 100644 +--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py ++++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py +@@ -53,6 +53,8 @@ from sglang.srt.managers.io_struct import ( + LoadLoRAAdapterReqOutput, + LoRAUpdateOutput, + OpenSessionReqInput, ++ PostProcessWeightsReqInput, ++ PostProcessWeightsReqOutput, + ProfileReq, + ProfileReqOutput, + ProfileReqType, +@@ -181,6 +183,9 @@ class TokenizerCommunicatorMixin: + self.update_weights_from_ipc_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) ++ self.post_process_weights_communicator = _Communicator( ++ self.send_to_scheduler, server_args.dp_size ++ ) + self.get_weights_by_name_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) +@@ -257,6 +262,10 @@ class TokenizerCommunicatorMixin: + UpdateWeightsFromIPCReqOutput, + self.update_weights_from_ipc_communicator.handle_recv, + ), ++ ( ++ PostProcessWeightsReqOutput, ++ self.post_process_weights_communicator.handle_recv, ++ ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, +@@ -444,6 +453,17 @@ class TokenizerCommunicatorMixin: + + return success, message + ++ async def post_process_weights( ++ self: TokenizerManager, ++ obj: PostProcessWeightsReqInput, ++ request: Optional[fastapi.Request] = None, ++ ) -> Tuple[bool, str]: ++ """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).""" ++ self.auto_create_handle_loop() ++ async with self.model_update_lock.writer_lock: ++ results = await self.post_process_weights_communicator(obj) ++ return _Communicator.merge_results(results) ++ + async def init_weights_send_group_for_remote_instance( + self, + obj: 
InitWeightsSendGroupForRemoteInstanceReqInput, +diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py +index af4000729..2d6b98502 100644 +--- a/python/sglang/srt/managers/tp_worker.py ++++ b/python/sglang/srt/managers/tp_worker.py +@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import ( + InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterFromTensorsReqInput, + LoadLoRAAdapterReqInput, ++ PostProcessWeightsReqInput, + SendWeightsToRemoteInstanceReqInput, + UnloadLoRAAdapterReqInput, + UpdateWeightFromDiskReqInput, +@@ -167,6 +168,11 @@ class BaseTpWorker(ABC): + success, message = self.model_runner.update_weights_from_ipc(recv_req) + return success, message + ++ def post_process_weights(self, recv_req: PostProcessWeightsReqInput): ++ """Perform optional post-processing on the updated model weights (e.g., Marlin conversion).""" ++ success, message = self.model_runner.post_process_weights(recv_req) ++ return success, message ++ + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): + parameter = self.model_runner.get_weights_by_name( + recv_req.name, recv_req.truncate_size +diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py +index 378421fc6..6aaf31f29 100644 +--- a/python/sglang/srt/mem_cache/memory_pool.py ++++ b/python/sglang/srt/mem_cache/memory_pool.py +@@ -1732,7 +1732,8 @@ class NSATokenToKVPool(MLATokenToKVPool): + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool +- else nullcontext() ++ else nullcontext(), ++ self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), + ): + self.index_k_with_scale_buffer = [ + torch.zeros( +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index af67d52e1..ecdeb05ed 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -566,7 +566,8 @@ class 
ModelRunner(ModelRunnerKVCacheMixin): + ) + + # Init routed experts capturer +- self.init_routed_experts_capturer() ++ if not self.is_draft_worker: ++ self.init_routed_experts_capturer() + + if self.device == "cuda": + self.init_cublas() +@@ -2263,11 +2264,19 @@ class ModelRunner(ModelRunnerKVCacheMixin): + output.expert_distribution_metrics = recorder_outputs.get("metrics") + + # Copy cached routing experts' buffers back to CPU cache +- get_global_experts_capturer().on_forward_end( +- forward_batch=forward_batch, +- can_run_graph=output.can_run_graph, +- cuda_graph_batch=getattr(self.graph_runner, "bs", None), +- ) ++ if not self.is_draft_worker: ++ # In speculative decoding, num_tokens_per_bs > 1, so we need to pass ++ # the actual number of tokens per dp rank in cuda graph, not batch size. ++ cuda_graph_num_tokens = None ++ if getattr(self.graph_runner, "bs", None): ++ cuda_graph_num_tokens = ( ++ self.graph_runner.bs * self.graph_runner.num_tokens_per_bs ++ ) ++ get_global_experts_capturer().on_forward_end( ++ forward_batch=forward_batch, ++ can_run_graph=output.can_run_graph, ++ cuda_graph_batch=cuda_graph_num_tokens, ++ ) + + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() +@@ -2475,6 +2484,41 @@ class ModelRunner(ModelRunnerKVCacheMixin): + logger.error(f"IPC weight update failed: {e}") + return False, str(e) + ++ def post_process_weights(self, recv_req): ++ """ ++ Execute post-processing logic for model weights, such as Marlin quantization format conversion. 
++ """ ++ from sglang.srt.model_loader.loader import device_loading_context ++ ++ target_device = torch.device("cuda", torch.cuda.current_device()) ++ ++ if recv_req.restore_weights_before_load: ++ for _, module in self.model.named_modules(): ++ quant_method = getattr(module, "quant_method", None) ++ ++ # Check if the module supports restoring weights ++ if quant_method is not None and hasattr( ++ quant_method, "restore_weights_before_loading" ++ ): ++ ++ with device_loading_context(module, target_device): ++ quant_method.restore_weights_before_loading(module) ++ ++ if recv_req.post_process_quantization: ++ # Iterate through all modules to apply specific post-loading processing ++ for _, module in self.model.named_modules(): ++ quant_method = getattr(module, "quant_method", None) ++ ++ # Check if the module supports quantization post-processing ++ if quant_method is not None and hasattr( ++ quant_method, "process_weights_after_loading" ++ ): ++ ++ # Apply the post-processing (e.g., repacking weights for Marlin kernel) ++ with device_loading_context(module, target_device): ++ quant_method.process_weights_after_loading(module) ++ ++ return True, "Success" + + def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): + params_dict = dict(model.named_parameters()) +diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py +index b6986977a..76a6e80cd 100644 +--- a/python/sglang/srt/models/deepseek_v2.py ++++ b/python/sglang/srt/models/deepseek_v2.py +@@ -151,6 +151,7 @@ from sglang.srt.utils import ( + make_layers, + use_intel_amx_backend, + ) ++from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend + + if _use_aiter_gfx95: + +diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py +index a7dbadec6..c83a41338 100644 +--- a/python/sglang/srt/models/qwen2.py ++++ b/python/sglang/srt/models/qwen2.py +@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): + self.act_fn = 
SiluAndMul() + + def forward(self, x): +- if get_global_server_args().rl_on_policy_target is not None: +- x = x.bfloat16() +- + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) +@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): + quant_config=quant_config, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), +- params_dtype=( +- torch.float32 +- if get_global_server_args().rl_on_policy_target is not None +- else None +- ), + ) + else: + self.embed_tokens = PPMissingLayer() +@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): + if self.pp_group.is_last_rank: + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py +index 1e4b53a7d..a24e34e8b 100644 +--- a/python/sglang/srt/models/qwen2_moe.py ++++ b/python/sglang/srt/models/qwen2_moe.py +@@ -592,7 +592,17 @@ class Qwen2MoeModel(nn.Module): + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: +- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.norm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + else: + self.norm = PPMissingLayer(return_tuple=True) + +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 89871ad57..b4ad8adfd 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ b/python/sglang/srt/models/qwen3.py +@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, ++ fp32_residual=False, + ) + if 
get_global_server_args().rl_on_policy_target is not None + else {} +@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index d0469a62e..2ab3d7e0e 100644 +--- a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -22,6 +22,7 @@ import math + from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + + import torch ++import torch.nn.functional as F + from torch import nn + from transformers import PretrainedConfig + +@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( + ) + from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +-from sglang.srt.layers.moe.topk import TopK ++from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK + from sglang.srt.layers.moe.utils import ( + RoutingMethodType, + filter_moe_weight_param_global_expert, +@@ -232,6 +233,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + use_grouped_topk=False, + layer_id=layer_id, + ) ++ self.top_k = config.num_experts_per_tok + + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.num_experts +@@ -300,7 +302,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) +- topk_output = self.topk(hidden_states, router_logits) ++ ++ if get_global_server_args().rl_on_policy_target is not None: ++ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) ++ routing_weights, selected_experts = torch.topk( ++ routing_weights, self.top_k, dim=-1 ++ ) ++ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) ++ routing_weights 
= routing_weights.to(hidden_states.dtype) ++ topk_output = StandardTopKOutput( ++ topk_weights=routing_weights, ++ topk_ids=selected_experts, ++ router_logits=router_logits, ++ ) ++ else: ++ topk_output = self.topk(hidden_states, router_logits) ++ + final_hidden_states = self.experts(hidden_states, topk_output) + if ( + self.tp_size > 1 +@@ -481,13 +498,14 @@ class Qwen3MoeAttention(nn.Module): + ) + self.compatible_with_fused_kv_buffer = ( + False if isinstance(self.rotary_emb, MRotaryEmbedding) else True +- ) ++ ) and (get_global_server_args().rl_on_policy_target is None) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope ++ and (get_global_server_args().rl_on_policy_target is None) + ) + self._used_fused_qk_norm_rope_last_call = False + +@@ -500,8 +518,16 @@ class Qwen3MoeAttention(nn.Module): + prefix=add_prefix("attn", prefix), + ) + +- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) +- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) ++ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.alt_stream = alt_stream + + def op_prepare(self, state): +@@ -742,9 +768,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) +- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.input_layernorm = RMSNorm( ++ config.hidden_size, 
eps=config.rms_norm_eps, **norm_kwargs ++ ) + self.post_attention_layernorm = RMSNorm( +- config.hidden_size, eps=config.rms_norm_eps ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) + + self.layer_communicator = LayerCommunicator( +diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py +index 93661fc78..89ed0b5ae 100644 +--- a/python/sglang/srt/models/qwen3_vl.py ++++ b/python/sglang/srt/models/qwen3_vl.py +@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return cos_combined, sin_combined + + def fast_pos_embed_interpolate(self, grid_thw): +- patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size ++ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] ++ num_grid_per_side = int(self.num_position_embeddings**0.5) ++ device = self.pos_embed.weight.device ++ ++ idx_list = [[] for _ in range(4)] ++ weight_list = [[] for _ in range(4)] ++ ++ for t, h, w in zip(grid_ts, grid_hs, grid_ws): ++ h_idxs = torch.linspace(0, num_grid_per_side - 1, h) ++ w_idxs = torch.linspace(0, num_grid_per_side - 1, w) ++ ++ h_idxs_floor = h_idxs.int() ++ w_idxs_floor = w_idxs.int() ++ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ ++ dh = h_idxs - h_idxs_floor ++ dw = w_idxs - w_idxs_floor ++ ++ base_h = h_idxs_floor * num_grid_per_side ++ base_h_ceil = h_idxs_ceil * num_grid_per_side ++ ++ indices = [ ++ (base_h[None].T + w_idxs_floor[None]).flatten(), ++ (base_h[None].T + w_idxs_ceil[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ++ ] ++ ++ weights = [ ++ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ++ ((1 - dh)[None].T * dw[None]).flatten(), ++ (dh[None].T * (1 - dw)[None]).flatten(), ++ (dh[None].T * dw[None]).flatten(), ++ ] + +- embeds = torch.arange(self.num_grid, 
device=self.pos_embed.weight.device) +- embeds = ( +- self.pos_embed(embeds) +- .permute(1, 0) +- .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side) ++ for i in range(4): ++ idx_list[i].extend(indices[i].tolist()) ++ weight_list[i].extend(weights[i].tolist()) ++ ++ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) ++ weight_tensor = torch.tensor( ++ weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) +- for t, h, w in grid_thw: +- pos_embed = torch.nn.functional.interpolate( +- embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners +- ) +- pos_embed = pos_embed.reshape( +- -1, +- h // self.spatial_merge_size, +- self.spatial_merge_size, +- w // self.spatial_merge_size, +- self.spatial_merge_size, ++ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] ++ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] ++ ++ patch_pos_embeds = patch_pos_embeds.split( ++ [h * w for h, w in zip(grid_hs, grid_ws)] ++ ) ++ ++ patch_pos_embeds_permute = [] ++ merge_size = self.spatial_merge_size ++ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): ++ pos_embed = pos_embed.repeat(t, 1) ++ pos_embed = ( ++ pos_embed.view( ++ t, h // merge_size, merge_size, w // merge_size, merge_size, -1 ++ ) ++ .permute(0, 1, 3, 2, 4, 5) ++ .flatten(0, 4) + ) +- pos_embed = pos_embed.permute(1, 3, 2, 4, 0) +- pos_embed = pos_embed.flatten(0, 3).repeat(t, 1) + patch_pos_embeds_permute.append(pos_embed) + return torch.cat(patch_pos_embeds_permute) + +@@ -610,14 +650,19 @@ class Qwen3LLMModel(Qwen3Model): + hidden_states + residual if residual is not None else hidden_states + ) + ++ deepstack_embeds = None ++ if input_deepstack_embeds is not None: ++ prev_layer_idx = layer_idx - 1 ++ if prev_layer_idx in self.deepstack_embed_to_decoder_layer: ++ sep = self.hidden_size * prev_layer_idx ++ deepstack_embeds = input_deepstack_embeds[ ++ :, sep : sep + self.hidden_size 
++ ] ++ + # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. + # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 + # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack + # The order matters because addition with different tensors is not associative in practice. +- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. +- deepstack_embeds = self.get_deepstack_embeds( +- layer_idx - 1, input_deepstack_embeds +- ) + hidden_states, residual = layer( + positions, + hidden_states, +diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py +index 684324b71..9a827c1ac 100644 +--- a/python/sglang/srt/server_args.py ++++ b/python/sglang/srt/server_args.py +@@ -555,6 +555,7 @@ class ServerArgs: + cuda_graph_max_bs: Optional[int] = None + cuda_graph_bs: Optional[List[int]] = None + disable_cuda_graph: bool = False ++ disable_draft_cuda_graph: bool = False + disable_cuda_graph_padding: bool = False + enable_profile_cuda_graph: bool = False + enable_cudagraph_gc: bool = False +@@ -4160,6 +4161,11 @@ class ServerArgs: + action="store_true", + help="Disable cuda graph.", + ) ++ parser.add_argument( ++ "--disable-draft-cuda-graph", ++ action="store_true", ++ help="Disable cuda graph for draft model in speculative decoding.", ++ ) + parser.add_argument( + "--disable-cuda-graph-padding", + action="store_true", +diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +index 5fe45086c..c95fbd0f6 100644 +--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py ++++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: + self.seq_lens.fill_(self.seq_len_fill_value) + self.out_cache_loc.zero_() + 
self.positions.zero_() +- ++ self.topk_p.zero_() ++ self.topk_index.zero_() ++ self.hidden_states.zero_() ++ self.req_pool_indices.zero_() + num_tokens = bs * self.num_tokens_per_bs + + # Common inputs +@@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner: + forward_batch.out_cache_loc + ) + self.positions[:raw_num_token].copy_(forward_batch.positions) +- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) +- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) ++ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) ++ self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1)) + self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + +diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py +index e72c9d725..3314cb9f7 100644 +--- a/python/sglang/srt/speculative/eagle_info.py ++++ b/python/sglang/srt/speculative/eagle_info.py +@@ -777,6 +777,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] ++ if self.accept_length is not None: ++ self.accept_length = self.accept_length[: len(new_indices)] ++ if self.accept_length_cpu is not None: ++ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] + else: + # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` + self.topk_p = self.topk_p[new_indices] +@@ -808,6 +812,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) ++ if 
self.accept_length is not None and spec_info.accept_length is not None: ++ self.accept_length = torch.cat( ++ [self.accept_length, spec_info.accept_length] ++ ) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif self.accept_length is not None: ++ zeros = torch.zeros( ++ [spec_info.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([self.accept_length, zeros]) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif spec_info.accept_length is not None: ++ zeros = torch.zeros( ++ [self.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([zeros, spec_info.accept_length]) ++ self.accept_length_cpu = self.accept_length.tolist() + + + @dataclass +diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py +index 5a6cc4b03..b06efed72 100644 +--- a/python/sglang/srt/speculative/eagle_worker.py ++++ b/python/sglang/srt/speculative/eagle_worker.py +@@ -230,7 +230,7 @@ class EAGLEWorker(TpModelWorker): + self.cuda_graph_runner = None + self.cuda_graph_runner_for_draft_extend = None + +- if self.server_args.disable_cuda_graph: ++ if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph: + return + + Device2DraftCudaGraphRunner = { +diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py +index df1b5f066..f90a9307c 100644 +--- a/python/sglang/srt/utils/common.py ++++ b/python/sglang/srt/utils/common.py +@@ -2244,6 +2244,8 @@ class SafeUnpickler(pickle.Unpickler): + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", ++ # --- slime --- ++ "slime.", + } + + DENY_CLASSES = {