diff --git a/chatlearn/algorithm/grpo_utils/policy_trainer.py b/chatlearn/algorithm/grpo_utils/policy_trainer.py
index b5d17fd4..b636cb25 100644
--- a/chatlearn/algorithm/grpo_utils/policy_trainer.py
+++ b/chatlearn/algorithm/grpo_utils/policy_trainer.py
@@ -22,6 +22,8 @@
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn.bert_padding import pad_input
+from packaging.version import Version as PkgVersion
+import transformers
from chatlearn import FSDPModule
from chatlearn.utils import to_device
@@ -37,7 +39,6 @@
split_and_unpadding,
unpad_input)
-
class PolicyTrainer(FSDPModule):
"""policy trainer"""
def setup(self):
@@ -120,6 +121,8 @@ def preprocess_data_list(self, data_list: List[Dict[str, Any]], training: bool):
if self.runtime_args.model_type == 'vlm':
# vl
position_ids = position_ids.permute(0, 2, 1).cpu()
+            if PkgVersion(transformers.__version__) >= PkgVersion('4.55.0'):
+                position_ids = torch.cat([position_ids[0:1], position_ids], dim=0) # prepend a text channel to the mRoPE position_ids (expected by transformers >= 4.55)
else:
            position_ids = position_ids.permute(1, 0).cpu() # for compatibility with transformers
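For orientation, here is a minimal sketch of the shape handling in the hunk above, assuming `get_rope_index` produces mRoPE position_ids with three channels (t/h/w) and that the version gate reflects transformers >= 4.55 expecting an extra leading text channel; the concrete shapes are illustrative only:

```python
import torch

# Illustrative only: three mRoPE channels (t, h, w), batch 1, seq 4.
position_ids = torch.arange(12).reshape(3, 1, 4)

# Duplicate the first channel as the text channel, mirroring the
# torch.cat in the hunk above, to get a 4-channel layout.
position_ids = torch.cat([position_ids[0:1], position_ids], dim=0)
print(position_ids.shape)  # torch.Size([4, 1, 4])
```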
@@ -168,6 +171,19 @@ def preprocess_data_list(self, data_list: List[Dict[str, Any]], training: bool):
data_after_process.append(data_obj)
return response_token_length_total, data_after_process
+    def compute_vl_position_ids(self, data_list: List[Dict[str, Any]]):
+        """Derive mRoPE position_ids for each sample from input_ids, image_grid_thw and attention_mask."""
+ input_ids_key = 'input_ids' if 'input_ids' in data_list[0] else 'prompt_token_ids'
+
+ for data_b in data_list:
+ position_ids, _ = self.model.model.get_rope_index(
+ input_ids=torch.tensor(data_b[input_ids_key]).unsqueeze(0),
+ image_grid_thw=data_b["image_grid_thw"],
+ attention_mask=torch.tensor(data_b['attention_mask']).unsqueeze(0)
+ )
+ data_b['position_ids'] = position_ids.squeeze().tolist()
+
+ return data_list
+
@monitor_error()
@compute_decorator(trainable=True, rollout=False)
@timeit()
@@ -298,6 +314,9 @@ def train_step(self, data_list: List[Dict[str, Any]], **kwargs): # pylint: disab
@compute_decorator(trainable=False, rollout=False)
@timeit()
def forward_step(self, data: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: # pylint: disable=unused-argument,arguments-differ
+ if self.runtime_args.model_type == 'vlm':
+ data = self.compute_vl_position_ids(data)
+
_, data_list = self.preprocess_data_list(data_list=data, training=False)
tag = "old_logprobs" if self.trainable else "ref_logprobs"
# Logprobs holder
diff --git a/chatlearn/data/vl_prompt_dataset.py b/chatlearn/data/vl_prompt_dataset.py
index 0d4e69b9..b97d1c38 100644
--- a/chatlearn/data/vl_prompt_dataset.py
+++ b/chatlearn/data/vl_prompt_dataset.py
@@ -6,9 +6,6 @@
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
-from chatlearn.models.patches.transformers.qwen2_5_vl_patch import get_rope_index
-
-
class PromptPipeline(Dataset):
"""
Input data_list: List[Dict])
@@ -41,6 +38,7 @@ class PromptPipeline(Dataset):
"mm_processor_kwargs": {'fps':[]}, # used for video useless now
"pixel_values": Tensor, # [grid_num, pixel_num]
"image_grid_thw": Tensor, # [1,3] 3 means t,h,w
+        "attention_mask": List, # used to compute position_ids
}
"""
def __init__(
@@ -98,15 +96,6 @@ def __init__(
# text only input_ids for vllm
raw_input_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
- # get position_ids used for sequence packing
- position_ids, _ = get_rope_index(
- self.processor,
- input_ids=input_ids,
- image_grid_thw=model_inputs.get("image_grid_thw"),
- video_grid_thw=model_inputs.get("video_grid_thw"),
- second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
- attention_mask=attention_mask,
- )
# for vl model, raw_input_ids is only text input_ids for vllm inference
# input_ids is used for model forward_step and sglang inference (with image pad)
@@ -116,11 +105,11 @@ def __init__(
"input_ids": input_ids[0].tolist(),
"prompt_token_length": len(input_ids[0].tolist()),
"prompt": raw_prompt,
- "position_ids": position_ids.squeeze().tolist(),
"multi_modal_data": multi_modal_data,
"mm_processor_kwargs": mm_processor_kwargs,
"pixel_values": pixel_values,
- "image_grid_thw": image_grid_thw
+ "image_grid_thw": image_grid_thw,
+ "attention_mask": attention_mask[0].tolist(),
})
if len(input_ids[0]) > self.max_prompt:
self.max_prompt = len(input_ids[0])
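A note on the new `attention_mask` field: the processor emits it with a batch dimension, and storing row 0 as a plain Python list keeps the sample dict serializable; position_ids are later rebuilt from it by `compute_vl_position_ids`. A toy illustration with hypothetical values:

```python
import torch

# The processor returns attention_mask shaped (1, seq_len); flattening
# row 0 into a list keeps the sample dict picklable.
attention_mask = torch.ones(1, 5, dtype=torch.long)
sample = {"attention_mask": attention_mask[0].tolist()}
print(sample)  # {'attention_mask': [1, 1, 1, 1, 1]}
```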
diff --git a/chatlearn/models/agent/agent_module.py b/chatlearn/models/agent/agent_module.py
index 98aceda9..ab74b368 100644
--- a/chatlearn/models/agent/agent_module.py
+++ b/chatlearn/models/agent/agent_module.py
@@ -89,7 +89,8 @@ def postprocess_func(
prompt_token_ids = output.prompt_ids
pixel_values = output.pixel_values
image_grid_thw = output.image_grid_thw
- position_ids = output.position_ids
+    attention_mask = output.attention_mask
+
response_token_length = len(output.all_token_ids) - len(output.prompt_ids)
prompt_token_length = len(output.prompt_ids)
str_outputs = output.str_output
@@ -108,7 +109,7 @@ def postprocess_func(
# multimodel related
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
- "position_ids": position_ids
+        "attention_mask": attention_mask
}
)
data_output.append(input_data)
diff --git a/chatlearn/models/agent/base_agent_graph.py b/chatlearn/models/agent/base_agent_graph.py
index f6532220..b9f7a9c6 100644
--- a/chatlearn/models/agent/base_agent_graph.py
+++ b/chatlearn/models/agent/base_agent_graph.py
@@ -26,9 +26,6 @@
find_last_ai_index,
find_first_ai_index)
from chatlearn.models.sglang_module import AsyncEngine
-from chatlearn.models.patches.transformers.qwen2_5_vl_patch import get_rope_index
-
-
def find_first_zero_group_end(lst):
for i, x in enumerate(lst):
if x != 0:
@@ -55,7 +52,7 @@ class AgentGraphOutput(BaseModel):
# multimodel related item
pixel_values: Any = None
image_grid_thw: Any = None
- position_ids: Any = None
+ attention_mask: Any = None
# Extra fields for dynamic addition.
extra_fields: dict[str, Any] = {}
@@ -108,21 +105,12 @@ def convert_agent_graph_output(self, messages: Dict) -> AgentGraphOutput:
num_turns = last_ai_message_idx + 1
str_output = self.tokenizer.decode(all_token_ids[prompt_end_idx + 1 :])
- pixel_values, image_grid_thw, position_ids = None, None, None
+ pixel_values, image_grid_thw, attention_mask = None, None, None
multimodel_batch_feature = messages[first_ai_message_idx].response_metadata.get("multimodel_batch_feature", None)
if multimodel_batch_feature:
pixel_values = multimodel_batch_feature.get("pixel_values")
image_grid_thw = multimodel_batch_feature.get("image_grid_thw")
- # need to get position ids used in sequence packing
- position_ids, _ = get_rope_index(
- self.processor,
- input_ids=multimodel_batch_feature.get("input_ids"),
- image_grid_thw=multimodel_batch_feature.get("image_grid_thw"),
- video_grid_thw=multimodel_batch_feature.get("video_grid_thw"),
- second_per_grid_ts=multimodel_batch_feature.get("second_per_grid_ts"),
- attention_mask=multimodel_batch_feature.get("attention_mask"),
- )
- position_ids = position_ids.squeeze().tolist()
+ attention_mask = multimodel_batch_feature.get("attention_mask")[0].tolist()
return AgentGraphOutput(
str_output=str_output,
@@ -132,5 +120,5 @@ def convert_agent_graph_output(self, messages: Dict) -> AgentGraphOutput:
num_turns=num_turns,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
- position_ids=position_ids
+ attention_mask=attention_mask
)
diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py
index 9709e2b7..1ecdb313 100644
--- a/chatlearn/models/fsdp_module.py
+++ b/chatlearn/models/fsdp_module.py
@@ -23,7 +23,6 @@
import numpy as np
from packaging.version import Version as PkgVersion
-import transformers
import torch
from torch import Tensor
import torch.distributed as dist
@@ -258,10 +257,9 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo
attn_implementation="flash_attention_2",
trust_remote_code=self.module_args.trust_remote_code
)
- if PkgVersion(transformers.__version__)==PkgVersion('4.51.3'):
- # vl patch needed for transformers 4.51.3
- from chatlearn.models.patches.monkey_patch import apply_qwenvl
- apply_qwenvl(model)
+
+ from chatlearn.models.patches.monkey_patch import apply_qwenvl
+ apply_qwenvl(model)
assert self.sp_size == 1, "VL model only support sp_size=1"
else:
@@ -273,14 +271,26 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo
)
else:
model_config = AutoConfig.from_pretrained(model_path)
- assert "Qwen2_5_VLForConditionalGeneration" not in model_config.architectures, "VL model not support meta init"
- with init_on_device('meta', include_buffers=False):
- model = AutoModelForCausalLM.from_config(
- model_config,
- torch_dtype=torch_dtype,
- attn_implementation="flash_attention_2",
- trust_remote_code=self.module_args.trust_remote_code
- )
+
+ if self.runtime_args.model_type == 'vlm':
+ with init_on_device('meta', include_buffers=False):
+ model = AutoModelForImageTextToText.from_pretrained(
+ pretrained_model_name_or_path=model_path,
+ torch_dtype=torch_dtype,
+ attn_implementation="flash_attention_2",
+ trust_remote_code=self.module_args.trust_remote_code
+ )
+
+ from chatlearn.models.patches.monkey_patch import apply_qwenvl
+ apply_qwenvl(model)
+ else:
+ with init_on_device('meta', include_buffers=False):
+ model = AutoModelForCausalLM.from_config(
+ model_config,
+ torch_dtype=torch_dtype,
+ attn_implementation="flash_attention_2",
+ trust_remote_code=self.module_args.trust_remote_code
+ )
dist.barrier()
return model
@property
@@ -507,9 +517,13 @@ def get_weight_ipc_handles_by_name(self, block_name: List[str]):
if rollout_engine == "sglang":
# lazy import sglang
from sglang.srt.utils import MultiprocessingSerializer
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
-
+ import sglang
+            if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+ else:
+ from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_reductions()
+
flattened_tensor, metadatas = self.convert_block2flattened_bucket(
block_parameter
)
diff --git a/chatlearn/models/patches/monkey_patch.py b/chatlearn/models/patches/monkey_patch.py
index a0be47d1..c7c281bb 100644
--- a/chatlearn/models/patches/monkey_patch.py
+++ b/chatlearn/models/patches/monkey_patch.py
@@ -13,6 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Apply patches for different model architectures"""
+from packaging.version import Version as PkgVersion
+import transformers
+
def apply_sp_monkey_patch(model_config):
print(f"applying sequence parallel patches for {model_config.architectures}")
if model_config.architectures[0] == "Qwen2ForCausalLM":
@@ -42,8 +45,15 @@ def apply_group_gemm(model):
def apply_qwenvl(model):
print(f"applying qwenvl patches for {model.config.architectures[0]}")
if model.config.architectures[0] == "Qwen2_5_VLForConditionalGeneration":
- from chatlearn.models.patches.transformers.qwen2_5_vl_patch import apply_qwenvl_patch \
+        if PkgVersion(transformers.__version__) == PkgVersion('4.51.3'):
+            # the Qwen2.5-VL patch is only needed for transformers 4.51.3
+ from chatlearn.models.patches.transformers.qwen2_5_vl_patch import apply_qwenvl_patch \
+ # pylint: disable=import-outside-toplevel
+ apply_qwenvl_patch()
+ elif model.config.architectures[0] in ["Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration"]:
+        assert PkgVersion(transformers.__version__) >= PkgVersion('4.57.0'), "qwen3vl requires transformers >= 4.57.0"
+ from chatlearn.models.patches.transformers.qwen3_vl_patch import apply_qwen3vl_patch \
# pylint: disable=import-outside-toplevel
- apply_qwenvl_patch()
+ apply_qwen3vl_patch()
else:
raise ValueError(f"Unsupported model architecture: {model.config.architectures} for qwenvl patch")
diff --git a/chatlearn/models/patches/transformers/qwen3_vl_patch.py b/chatlearn/models/patches/transformers/qwen3_vl_patch.py
new file mode 100644
index 00000000..c1145ff7
--- /dev/null
+++ b/chatlearn/models/patches/transformers/qwen3_vl_patch.py
@@ -0,0 +1,33 @@
+"""patches for qwen3 vl model"""
+from typing import Optional
+import torch
+
+def Qwen3VLBlock_patched_forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ # =========================================================================
+    # force a dtype cast for qwen3_vl; without it, backward raises a dtype mismatch error
+ hidden_states = hidden_states.to(self.norm1.weight.dtype)
+ # =========================================================================
+
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+def apply_qwen3vl_patch():
+ # pylint: disable=import-outside-toplevel
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionBlock
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionBlock
+ Qwen3VLVisionBlock.forward = Qwen3VLBlock_patched_forward
+ Qwen3VLMoeVisionBlock.forward = Qwen3VLBlock_patched_forward
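The failure mode this patch guards against can be reproduced in isolation: a normalization layer holding bf16 weights fed fp32 activations may raise a dtype mismatch (the exact behavior depends on the PyTorch build and device). A small self-contained repro under that assumption:

```python
import torch

# bf16 norm weights, fp32 activations: depending on the PyTorch
# version/device this errors in forward or backward.
norm = torch.nn.LayerNorm(8).to(torch.bfloat16)
x = torch.randn(2, 8)
try:
    norm(x).sum().backward()
    print("no mismatch on this build")
except RuntimeError as err:
    print("dtype mismatch:", err)

# The patch's remedy: cast activations to the weight dtype first,
# exactly as Qwen3VLBlock_patched_forward does before norm1.
out = norm(x.to(norm.weight.dtype))
print(out.dtype)  # torch.bfloat16
```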
diff --git a/chatlearn/models/sglang_module.py b/chatlearn/models/sglang_module.py
index dfb91970..1ded1602 100644
--- a/chatlearn/models/sglang_module.py
+++ b/chatlearn/models/sglang_module.py
@@ -570,8 +570,12 @@ def parameter_sync(self):
def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
"""Used for Mcore2SGLang Parameter Sync
"""
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
+        if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+ else:
+ from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_reductions()
+
param_id_to_update = set()
for bucket in buckets:
if bucket is None:
@@ -766,8 +770,12 @@ async def update_weights_from_ipc_handles(self, reduce_data):
@torch.no_grad()
async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
+        if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+ else:
+ from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_reductions()
+
param_id_to_update = set()
for bucket in buckets:
if bucket is None:
diff --git a/docker/torch/Dockerfile.torch2.8.0.sglang053 b/docker/torch/Dockerfile.torch2.8.0.sglang053
new file mode 100644
index 00000000..085e044d
--- /dev/null
+++ b/docker/torch/Dockerfile.torch2.8.0.sglang053
@@ -0,0 +1,15 @@
+FROM nvcr.io/nvidia/pytorch:24.12-py3
+ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
+ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
+RUN pip install --no-cache-dir "sglang[all]==0.5.3.post1"
+RUN pip install --no-cache-dir transformers==4.57.0
+RUN pip install --no-cache-dir langgraph==0.6.6
+RUN pip install --no-cache-dir ray[default]==2.46.0
+RUN pip install --no-cache-dir accelerate==1.10.0
+RUN pip install --no-cache-dir wandb==0.19.3
+RUN pip install --no-cache-dir hydra-core==1.3.2
+RUN pip install --no-cache-dir grpcio==1.70.0 nvidia-modelopt==0.27.0 nvidia-modelopt-core==0.27.0 datasets==3.6.0 deepspeed==0.16.7
+RUN pip install --no-cache-dir mathruler==0.1.0 pylatexenc==2.10 qwen-vl-utils==0.0.14
+RUN pip uninstall -y flash_attn && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/flash-attention/torch2.8.0-cu12x/flash_attn-2.7.4.post1-cp312-cp312-linux_x86_64.whl
+RUN pip uninstall -y transformer_engine && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.8.0-cuda12x/transformer_engine-2.3.0%2B5de3e148-cp312-cp312-linux_x86_64.whl
+RUN pip uninstall -y apex && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.8.0-cuda12x/apex-0.1-cp312-cp312-linux_x86_64.whl
\ No newline at end of file
diff --git a/docs/en/installation.md b/docs/en/installation.md
index 8c8aa9a0..81ef42ca 100644
--- a/docs/en/installation.md
+++ b/docs/en/installation.md
@@ -14,11 +14,19 @@ dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2
### SGLang
-You can prepare the image by referring to [Dockerfile.torch2.8.0.sglang052](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang052). Alternatively, you can directly pull and use the following image:
+You can prepare the image by referring to [Dockerfile.torch2.8.0.sglang053](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang053). Alternatively, you can directly pull and use the following image:
```bash
-dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312
+dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312
```
+### Image History
+
+| Image URL | Pkg Version | Model List |
+| ------------------------------------------------------------ | ----------------------------------------- | ------------------------------------------ |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312 | sglang 0.5.3.post1<br>transformers 4.57.0 | Qwen3-VL<br>Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312 | sglang 0.5.2<br>transformers 4.56.1 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-te2.7-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformer_engine 2.7 | Moonlight<br>Deepseek-r1 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformers 4.51.3 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
## 2. Code Preparation
diff --git a/docs/zh/installation.md b/docs/zh/installation.md
index befa2ac5..5aa8db78 100644
--- a/docs/zh/installation.md
+++ b/docs/zh/installation.md
@@ -13,12 +13,21 @@ dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2
### SGLang
-可以参考 [Dockerfile.torch2.8.0.sglang052](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang052) 准备镜像。也可以直接拉取如下镜像地址直接进行使用。
+可以参考 [Dockerfile.torch2.8.0.sglang053](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang053) 准备镜像。也可以直接拉取如下镜像地址直接进行使用。
```bash
-dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312
+dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312
```
+### 镜像历史
+
+| 镜像地址 | 包版本 | 模型列表 |
+| ------------------------------------------------------------ | ----------------------------------------- | ------------------------------------------ |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312 | sglang 0.5.3.post1<br>transformers 4.57.0 | Qwen3-VL<br>Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312 | sglang 0.5.2<br>transformers 4.56.1 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-te2.7-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformer_engine 2.7 | Moonlight<br>Deepseek-r1 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformers 4.51.3 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+
## 2. 代码准备
```
diff --git a/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_30b_grpo.sh b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_30b_grpo.sh
new file mode 100644
index 00000000..af60f23f
--- /dev/null
+++ b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_30b_grpo.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# Tested on 8xH20-3e with 140G VRAM
+set -x
+
+export CHATLEARN=$(pwd)
+export PYTHONPATH=${CHATLEARN}:${PYTHONPATH}
+source scripts/base_env.sh
+export RAY_DEDUP_LOGS=1
+export exp_name=qwen3-vl-grpo-30b-sglang
+
+python chatlearn/entrypoint.py grpo \
+ --config-file template/grpo_fsdp.yaml \
+ runtime_args.exp_name=${exp_name} \
+ runtime_args.rollout_backend=sglang \
+ runtime_args.model_type=vlm \
+ runtime_args.data_path=${CHATLEARN}/dataset/geo3k/train.parquet \
+ runtime_args.eval_data_path=${CHATLEARN}/dataset/geo3k/test.parquet \
+ runtime_args.output_dir=${CHATLEARN}/output/${exp_name} \
+ runtime_args.num_episode=200 \
+ runtime_args.sample_per_episode=512 \
+ runtime_args.train_global_batch_size=512 \
+ runtime_args.train_micro_batch_size=8 \
+ runtime_args.save_episode_interval=5 \
+ runtime_args.eval_episode_interval=5 \
+ runtime_args.enable_eval_before_training=False \
+ runtime_args.log_args_dict.enable_wandb=False \
+ runtime_args.log_args_dict.wandb_project=your_wandb_project \
+ models.policy_trainer.num_gpu=${num_device} \
+ models.policy_trainer.packing=True \
+ models.policy_trainer.meta_init=True \
+ models.policy_trainer.groupgemm=True \
+ models.policy_trainer.generation_batch_size=64 \
+ models.policy_trainer.ulysses_sequence_parallel_size=1 \
+ models.policy_trainer.load=${CHATLEARN}/pretrained_models/Qwen3-VL-30B-A3B-Instruct/ \
+ models.policy_trainer.optimizer.lr=1e-6 \
+ models.policy_trainer.pos_clip_ratio=0.2 \
+ models.policy_trainer.neg_clip_ratio=0.2 \
+ models.policy_trainer.kl_coef=0.01 \
+ models.ref_policy.generation_batch_size=64 \
+ models.policy.generation_batch_size=64 \
+ models.policy.enforce_eager=False \
+ models.policy.tensor_model_parallel_size=1 \
+ models.policy.max_prompt_tokens_length=1024 \
+ models.policy.max_response_tokens_length=2048 \
+ models.policy.num_inference_per_prompt=4 \
+ models.policy.gpu_memory_utilization=0.85 \
+ models.policy.enable_thinking=False \
+ models.reward.generation_batch_size=256 \
+ 2>&1 | tee log_${exp_name}.log ; exit ${PIPESTATUS[0]}
diff --git a/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_8b_grpo.sh b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_8b_grpo.sh
new file mode 100644
index 00000000..3bdc63cb
--- /dev/null
+++ b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_8b_grpo.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# Tested on 8xH20-3e with 140G VRAM
+set -x
+
+export CHATLEARN=$(pwd)
+export PYTHONPATH=${CHATLEARN}:${PYTHONPATH}
+source scripts/base_env.sh
+export RAY_DEDUP_LOGS=1
+export exp_name=qwen3-vl-grpo-8b-sglang
+
+python chatlearn/entrypoint.py grpo \
+ --config-file template/grpo_fsdp.yaml \
+ runtime_args.exp_name=${exp_name} \
+ runtime_args.rollout_backend=sglang \
+ runtime_args.model_type=vlm \
+ runtime_args.data_path=${CHATLEARN}/dataset/geo3k/train.parquet \
+ runtime_args.eval_data_path=${CHATLEARN}/dataset/geo3k/test.parquet \
+ runtime_args.output_dir=${CHATLEARN}/output/${exp_name} \
+ runtime_args.num_episode=200 \
+ runtime_args.sample_per_episode=512 \
+ runtime_args.train_global_batch_size=512 \
+ runtime_args.train_micro_batch_size=8 \
+ runtime_args.save_episode_interval=5 \
+ runtime_args.eval_episode_interval=5 \
+ runtime_args.enable_eval_before_training=False \
+ runtime_args.log_args_dict.enable_wandb=False \
+ runtime_args.log_args_dict.wandb_project=your_wandb_project \
+ models.policy_trainer.num_gpu=${num_device} \
+ models.policy_trainer.packing=True \
+ models.policy_trainer.meta_init=False \
+ models.policy_trainer.groupgemm=False \
+ models.policy_trainer.generation_batch_size=64 \
+ models.policy_trainer.ulysses_sequence_parallel_size=1 \
+ models.policy_trainer.load=${CHATLEARN}/pretrained_models/Qwen3-VL-8B-Instruct/ \
+ models.policy_trainer.optimizer.lr=1e-6 \
+ models.policy_trainer.pos_clip_ratio=0.2 \
+ models.policy_trainer.neg_clip_ratio=0.2 \
+ models.policy_trainer.kl_coef=0.01 \
+ models.ref_policy.generation_batch_size=64 \
+ models.policy.generation_batch_size=64 \
+ models.policy.enforce_eager=False \
+ models.policy.tensor_model_parallel_size=1 \
+ models.policy.max_prompt_tokens_length=1024 \
+ models.policy.max_response_tokens_length=2048 \
+ models.policy.num_inference_per_prompt=4 \
+ models.policy.gpu_memory_utilization=0.85 \
+ models.policy.enable_thinking=False \
+ models.reward.generation_batch_size=256 \
+ 2>&1 | tee log_${exp_name}.log ; exit ${PIPESTATUS[0]}