diff --git a/chatlearn/algorithm/grpo_utils/policy_trainer.py b/chatlearn/algorithm/grpo_utils/policy_trainer.py
index b5d17fd4..b636cb25 100644
--- a/chatlearn/algorithm/grpo_utils/policy_trainer.py
+++ b/chatlearn/algorithm/grpo_utils/policy_trainer.py
@@ -22,6 +22,8 @@
 import torch.distributed as dist
 import torch.nn.functional as F
 from flash_attn.bert_padding import pad_input
+from packaging.version import Version as PkgVersion
+import transformers
 
 from chatlearn import FSDPModule
 from chatlearn.utils import to_device
@@ -37,7 +39,6 @@
                                       split_and_unpadding,
                                       unpad_input)
 
-
 class PolicyTrainer(FSDPModule):
     """policy trainer"""
     def setup(self):
@@ -120,6 +121,8 @@ def preprocess_data_list(self, data_list: List[Dict[str, Any]], training: bool):
         if self.runtime_args.model_type == 'vlm':
             # vl
             position_ids = position_ids.permute(0, 2, 1).cpu()
+            if PkgVersion(transformers.__version__) >= PkgVersion('4.55.0'):
+                position_ids = torch.cat([position_ids[0:1], position_ids], dim=0) # add text position_ids for vl
         else:
             position_ids = position_ids.permute(1, 0).cpu()
         # For compatible with transformers
@@ -168,6 +171,19 @@ def preprocess_data_list(self, data_list: List[Dict[str, Any]], training: bool):
             data_after_process.append(data_obj)
         return response_token_length_total, data_after_process
 
+    def compute_vl_position_ids(self, data_list: List[Dict[str, Any]]):
+        input_ids_key = 'input_ids' if 'input_ids' in data_list[0] else 'prompt_token_ids'
+
+        for data_b in data_list:
+            position_ids, _ = self.model.model.get_rope_index(
+                input_ids=torch.tensor(data_b[input_ids_key]).unsqueeze(0),
+                image_grid_thw=data_b["image_grid_thw"],
+                attention_mask=torch.tensor(data_b['attention_mask']).unsqueeze(0)
+            )
+            data_b['position_ids'] = position_ids.squeeze().tolist()
+
+        return data_list
+
     @monitor_error()
     @compute_decorator(trainable=True, rollout=False)
     @timeit()
@@ -298,6 +314,9 @@ def train_step(self, data_list: List[Dict[str, Any]], **kwargs): # pylint: disab
     @compute_decorator(trainable=False, rollout=False)
     @timeit()
     def forward_step(self, data: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: # pylint: disable=unused-argument,arguments-differ
+        if self.runtime_args.model_type == 'vlm':
+            data = self.compute_vl_position_ids(data)
+
         _, data_list = self.preprocess_data_list(data_list=data, training=False)
         tag = "old_logprobs" if self.trainable else "ref_logprobs"
         # Logprobs holder
diff --git a/chatlearn/data/vl_prompt_dataset.py b/chatlearn/data/vl_prompt_dataset.py
index 0d4e69b9..b97d1c38 100644
--- a/chatlearn/data/vl_prompt_dataset.py
+++ b/chatlearn/data/vl_prompt_dataset.py
@@ -6,9 +6,6 @@
 from transformers import AutoTokenizer, AutoProcessor
 from qwen_vl_utils import process_vision_info
 
-from chatlearn.models.patches.transformers.qwen2_5_vl_patch import get_rope_index
-
-
 class PromptPipeline(Dataset):
     """
     Input data_list: List[Dict])
@@ -41,6 +38,7 @@ class PromptPipeline(Dataset):
         "mm_processor_kwargs": {'fps':[]}, # used for video useless now
         "pixel_values": Tensor, # [grid_num, pixel_num]
         "image_grid_thw": Tensor, # [1,3] 3 means t,h,w
+        "attention_mask": List, used to compute position_ids
     }
     """
     def __init__(
@@ -98,15 +96,6 @@ def __init__(
         # text only input_ids for vllm
         raw_input_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
 
-        # get position_ids used for sequence packing
-        position_ids, _ = get_rope_index(
-            self.processor,
-            input_ids=input_ids,
-            image_grid_thw=model_inputs.get("image_grid_thw"),
-            video_grid_thw=model_inputs.get("video_grid_thw"),
-            second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
-            attention_mask=attention_mask,
-        )
 
         # for vl model, raw_input_ids is only text input_ids for vllm inference
         # input_ids is used for model forward_step and sglang inference (with image pad)
@@ -116,11 +105,11 @@ def __init__(
             "input_ids": input_ids[0].tolist(),
             "prompt_token_length": len(input_ids[0].tolist()),
             "prompt": raw_prompt,
-            "position_ids": position_ids.squeeze().tolist(),
             "multi_modal_data": multi_modal_data,
             "mm_processor_kwargs": mm_processor_kwargs,
             "pixel_values": pixel_values,
-            "image_grid_thw": image_grid_thw
+            "image_grid_thw": image_grid_thw,
+            "attention_mask": attention_mask[0].tolist(),
         })
         if len(input_ids[0]) > self.max_prompt:
             self.max_prompt = len(input_ids[0])
diff --git a/chatlearn/models/agent/agent_module.py b/chatlearn/models/agent/agent_module.py
index 98aceda9..ab74b368 100644
--- a/chatlearn/models/agent/agent_module.py
+++ b/chatlearn/models/agent/agent_module.py
@@ -89,7 +89,8 @@ def postprocess_func(
     prompt_token_ids = output.prompt_ids
     pixel_values = output.pixel_values
     image_grid_thw = output.image_grid_thw
-    position_ids = output.position_ids
+    attention_mask = output.attention_mask
+
     response_token_length = len(output.all_token_ids) - len(output.prompt_ids)
     prompt_token_length = len(output.prompt_ids)
     str_outputs = output.str_output
@@ -108,7 +109,7 @@ def postprocess_func(
             # multimodel related
             "pixel_values": pixel_values,
             "image_grid_thw": image_grid_thw,
-            "position_ids": position_ids
+            "attention_mask": attention_mask
         }
     )
     data_output.append(input_data)
diff --git a/chatlearn/models/agent/base_agent_graph.py b/chatlearn/models/agent/base_agent_graph.py
index f6532220..b9f7a9c6 100644
--- a/chatlearn/models/agent/base_agent_graph.py
+++ b/chatlearn/models/agent/base_agent_graph.py
@@ -26,9 +26,6 @@
                                     find_last_ai_index,
                                     find_first_ai_index)
 from chatlearn.models.sglang_module import AsyncEngine
-from chatlearn.models.patches.transformers.qwen2_5_vl_patch import get_rope_index
-
-
 def find_first_zero_group_end(lst):
     for i, x in enumerate(lst):
         if x != 0:
@@ -55,7 +52,7 @@ class AgentGraphOutput(BaseModel):
     # multimodel related item
     pixel_values: Any = None
     image_grid_thw: Any = None
-    position_ids: Any = None
+    attention_mask: Any = None
 
     # Extra fields for dynamic addition.
     extra_fields: dict[str, Any] = {}
@@ -108,21 +105,12 @@ def convert_agent_graph_output(self, messages: Dict) -> AgentGraphOutput:
         num_turns = last_ai_message_idx + 1
         str_output = self.tokenizer.decode(all_token_ids[prompt_end_idx + 1 :])
 
-        pixel_values, image_grid_thw, position_ids = None, None, None
+        pixel_values, image_grid_thw, attention_mask = None, None, None
         multimodel_batch_feature = messages[first_ai_message_idx].response_metadata.get("multimodel_batch_feature", None)
         if multimodel_batch_feature:
             pixel_values = multimodel_batch_feature.get("pixel_values")
             image_grid_thw = multimodel_batch_feature.get("image_grid_thw")
-            # need to get position ids used in sequence packing
-            position_ids, _ = get_rope_index(
-                self.processor,
-                input_ids=multimodel_batch_feature.get("input_ids"),
-                image_grid_thw=multimodel_batch_feature.get("image_grid_thw"),
-                video_grid_thw=multimodel_batch_feature.get("video_grid_thw"),
-                second_per_grid_ts=multimodel_batch_feature.get("second_per_grid_ts"),
-                attention_mask=multimodel_batch_feature.get("attention_mask"),
-            )
-            position_ids = position_ids.squeeze().tolist()
+            attention_mask = multimodel_batch_feature.get("attention_mask")[0].tolist()
 
         return AgentGraphOutput(
             str_output=str_output,
@@ -132,5 +120,5 @@ def convert_agent_graph_output(self, messages: Dict) -> AgentGraphOutput:
             num_turns=num_turns,
             pixel_values=pixel_values,
             image_grid_thw=image_grid_thw,
-            position_ids=position_ids
+            attention_mask=attention_mask
         )
diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py
index 9709e2b7..1ecdb313 100644
--- a/chatlearn/models/fsdp_module.py
+++ b/chatlearn/models/fsdp_module.py
@@ -23,7 +23,6 @@
 import numpy as np
 from packaging.version import Version as PkgVersion
-import transformers
 import torch
 from torch import Tensor
 import torch.distributed as dist
@@ -258,10 +257,9 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo
                 attn_implementation="flash_attention_2",
                 trust_remote_code=self.module_args.trust_remote_code
             )
-            if PkgVersion(transformers.__version__)==PkgVersion('4.51.3'):
-                # vl patch needed for transformers 4.51.3
-                from chatlearn.models.patches.monkey_patch import apply_qwenvl
-                apply_qwenvl(model)
+
+            from chatlearn.models.patches.monkey_patch import apply_qwenvl
+            apply_qwenvl(model)
 
             assert self.sp_size == 1, "VL model only support sp_size=1"
         else:
@@ -273,14 +271,26 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo
             )
         else:
             model_config = AutoConfig.from_pretrained(model_path)
-            assert "Qwen2_5_VLForConditionalGeneration" not in model_config.architectures, "VL model not support meta init"
-            with init_on_device('meta', include_buffers=False):
-                model = AutoModelForCausalLM.from_config(
-                    model_config,
-                    torch_dtype=torch_dtype,
-                    attn_implementation="flash_attention_2",
-                    trust_remote_code=self.module_args.trust_remote_code
-                )
+
+            if self.runtime_args.model_type == 'vlm':
+                with init_on_device('meta', include_buffers=False):
+                    model = AutoModelForImageTextToText.from_pretrained(
+                        pretrained_model_name_or_path=model_path,
+                        torch_dtype=torch_dtype,
+                        attn_implementation="flash_attention_2",
+                        trust_remote_code=self.module_args.trust_remote_code
+                    )
+
+                from chatlearn.models.patches.monkey_patch import apply_qwenvl
+                apply_qwenvl(model)
+            else:
+                with init_on_device('meta', include_buffers=False):
+                    model = AutoModelForCausalLM.from_config(
+                        model_config,
+                        torch_dtype=torch_dtype,
+                        attn_implementation="flash_attention_2",
+                        trust_remote_code=self.module_args.trust_remote_code
+                    )
         dist.barrier()
         return model
 
     @property
@@ -507,9 +517,13 @@ def get_weight_ipc_handles_by_name(self, block_name: List[str]):
         if rollout_engine == "sglang":
             # lazy import sglang
             from sglang.srt.utils import MultiprocessingSerializer
-            from sglang.srt.patch_torch import monkey_patch_torch_reductions
-
+            import sglang
+            if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
+                from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+            else:
+                from sglang.srt.patch_torch import monkey_patch_torch_reductions
             monkey_patch_torch_reductions()
+
             flattened_tensor, metadatas = self.convert_block2flattened_bucket(
                 block_parameter
             )
diff --git a/chatlearn/models/patches/monkey_patch.py b/chatlearn/models/patches/monkey_patch.py
index a0be47d1..c7c281bb 100644
--- a/chatlearn/models/patches/monkey_patch.py
+++ b/chatlearn/models/patches/monkey_patch.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """Apply patches for different model architectures"""
+from packaging.version import Version as PkgVersion
+import transformers
+
 def apply_sp_monkey_patch(model_config):
     print(f"applying sequence parallel patches for {model_config.architectures}")
     if model_config.architectures[0] == "Qwen2ForCausalLM":
@@ -42,8 +45,15 @@ def apply_group_gemm(model):
 def apply_qwenvl(model):
     print(f"applying qwenvl patches for {model.config.architectures[0]}")
     if model.config.architectures[0] == "Qwen2_5_VLForConditionalGeneration":
-        from chatlearn.models.patches.transformers.qwen2_5_vl_patch import apply_qwenvl_patch \
+        if PkgVersion(transformers.__version__) == PkgVersion('4.51.3'):
+            # vl2.5 patch needed for transformers 4.51.3
+            from chatlearn.models.patches.transformers.qwen2_5_vl_patch import apply_qwenvl_patch \
+                # pylint: disable=import-outside-toplevel
+            apply_qwenvl_patch()
+    elif model.config.architectures[0] in ["Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration"]:
+        assert PkgVersion(transformers.__version__) >= PkgVersion('4.57.0'), "qwen3vl requires transformers >= 4.57.0"
+        from chatlearn.models.patches.transformers.qwen3_vl_patch import apply_qwen3vl_patch \
             # pylint: disable=import-outside-toplevel
-        apply_qwenvl_patch()
+        apply_qwen3vl_patch()
     else:
         raise ValueError(f"Unsupported model architecture: {model.config.architectures} for qwenvl patch")
diff --git a/chatlearn/models/patches/transformers/qwen3_vl_patch.py b/chatlearn/models/patches/transformers/qwen3_vl_patch.py
new file mode 100644
index 00000000..c1145ff7
--- /dev/null
+++ b/chatlearn/models/patches/transformers/qwen3_vl_patch.py
@@ -0,0 +1,33 @@
+"""patches for qwen3 vl model"""
+from typing import Optional
+import torch
+
+def Qwen3VLBlock_patched_forward(
+    self,
+    hidden_states: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    rotary_pos_emb: Optional[torch.Tensor] = None,
+    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+    **kwargs,
+) -> torch.Tensor:
+    # =========================================================================
+    # force hidden_states to the norm dtype for qwen3_vl; otherwise backward raises a dtype mismatch error
+    hidden_states = hidden_states.to(self.norm1.weight.dtype)
+    # =========================================================================
+
+    hidden_states = hidden_states + self.attn(
+        self.norm1(hidden_states),
+        cu_seqlens=cu_seqlens,
+        rotary_pos_emb=rotary_pos_emb,
+        position_embeddings=position_embeddings,
+        **kwargs,
+    )
+    hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+    return hidden_states
+
+def apply_qwen3vl_patch():
+    # pylint: disable=import-outside-toplevel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionBlock
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionBlock
+    Qwen3VLVisionBlock.forward = Qwen3VLBlock_patched_forward
+    Qwen3VLMoeVisionBlock.forward = Qwen3VLBlock_patched_forward
diff --git a/chatlearn/models/sglang_module.py b/chatlearn/models/sglang_module.py
index dfb91970..1ded1602 100644
--- a/chatlearn/models/sglang_module.py
+++ b/chatlearn/models/sglang_module.py
@@ -570,8 +570,12 @@ def parameter_sync(self):
     def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
         """Used for Mcore2SGLang Parameter Sync
         """
-        from sglang.srt.patch_torch import monkey_patch_torch_reductions
+        if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
+            from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+        else:
+            from sglang.srt.patch_torch import monkey_patch_torch_reductions
         monkey_patch_torch_reductions()
+
         param_id_to_update = set()
         for bucket in buckets:
             if bucket is None:
@@ -766,8 +770,12 @@ async def update_weights_from_ipc_handles(self, reduce_data):
 
     @torch.no_grad()
     async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
-        from sglang.srt.patch_torch import monkey_patch_torch_reductions
+        if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
+            from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+        else:
+            from sglang.srt.patch_torch import monkey_patch_torch_reductions
         monkey_patch_torch_reductions()
+
         param_id_to_update = set()
         for bucket in buckets:
             if bucket is None:
diff --git a/docker/torch/Dockerfile.torch2.8.0.sglang053 b/docker/torch/Dockerfile.torch2.8.0.sglang053
new file mode 100644
index 00000000..085e044d
--- /dev/null
+++ b/docker/torch/Dockerfile.torch2.8.0.sglang053
@@ -0,0 +1,15 @@
+FROM nvcr.io/nvidia/pytorch:24.12-py3
+ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
+ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
+RUN pip install --no-cache-dir "sglang[all]==0.5.3.post1"
+RUN pip install --no-cache-dir transformers==4.57.0
+RUN pip install --no-cache-dir langgraph==0.6.6
+RUN pip install --no-cache-dir ray[default]==2.46.0
+RUN pip install --no-cache-dir accelerate==1.10.0
+RUN pip install --no-cache-dir wandb==0.19.3
+RUN pip install --no-cache-dir hydra-core==1.3.2
+RUN pip install --no-cache-dir grpcio==1.70.0 nvidia-modelopt==0.27.0 nvidia-modelopt-core==0.27.0 datasets==3.6.0 deepspeed==0.16.7
+RUN pip install --no-cache-dir mathruler==0.1.0 pylatexenc==2.10 qwen-vl-utils==0.0.14
+RUN pip uninstall -y flash_attn && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/flash-attention/torch2.8.0-cu12x/flash_attn-2.7.4.post1-cp312-cp312-linux_x86_64.whl
+RUN pip uninstall -y transformer_engine && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.8.0-cuda12x/transformer_engine-2.3.0%2B5de3e148-cp312-cp312-linux_x86_64.whl
+RUN pip uninstall -y apex && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.8.0-cuda12x/apex-0.1-cp312-cp312-linux_x86_64.whl
\ No newline at end of file
diff --git a/docs/en/installation.md b/docs/en/installation.md
index 8c8aa9a0..81ef42ca 100644
--- a/docs/en/installation.md
+++ b/docs/en/installation.md
@@ -14,11 +14,19 @@
 dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2
 
 ### SGLang
 
-You can prepare the image by referring to [Dockerfile.torch2.8.0.sglang052](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang052). Alternatively, you can directly pull and use the following image:
+You can prepare the image by referring to [Dockerfile.torch2.8.0.sglang053](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang053). Alternatively, you can directly pull and use the following image:
 
 ```bash
-dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312
+dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312
 ```
 
+### Image History
+
+| Image URL | Pkg Version | Model List |
+| ------------------------------------------------------------ | ----------------------------------------- | ------------------------------------------ |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312 | sglang 0.5.3.post1<br>transformers 4.57.0 | Qwen3-VL<br>Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312 | sglang 0.5.2<br>transformers 4.56.1 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-te2.7-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformer_engine 2.7 | Moonlight<br>Deepseek-r1 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformers 4.51.3 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
 ## 2. Code Preparation
diff --git a/docs/zh/installation.md b/docs/zh/installation.md
index befa2ac5..5aa8db78 100644
--- a/docs/zh/installation.md
+++ b/docs/zh/installation.md
@@ -13,12 +13,21 @@
 dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2
 
 ### SGLang
 
-可以参考 [Dockerfile.torch2.8.0.sglang052](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang052) 准备镜像。也可以直接拉取如下镜像地址直接进行使用。
+可以参考 [Dockerfile.torch2.8.0.sglang053](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang053) 准备镜像。也可以直接拉取如下镜像地址直接进行使用。
 
 ```bash
-dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312
+dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312
 ```
 
+### 镜像历史
+
+| 镜像地址 | 包版本 | 模型列表 |
+| ------------------------------------------------------------ | ----------------------------------------- | ------------------------------------------ |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312 | sglang 0.5.3.post1<br>transformers 4.57.0 | Qwen3-VL<br>Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312 | sglang 0.5.2<br>transformers 4.56.1 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-te2.7-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformer_engine 2.7 | Moonlight<br>Deepseek-r1 |
+| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformers 4.51.3 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
+
 ## 2. 代码准备
 
 ```
diff --git a/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_30b_grpo.sh b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_30b_grpo.sh
new file mode 100644
index 00000000..af60f23f
--- /dev/null
+++ b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_30b_grpo.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# Tested on 8xH20-3e with 140G VRAM
+set -x
+
+export CHATLEARN=$(pwd)
+export PYTHONPATH=${CHATLEARN}:${PYTHONPATH}
+source scripts/base_env.sh
+export RAY_DEDUP_LOGS=1
+export exp_name=qwen3-vl-grpo-30b-sglang
+
+python chatlearn/entrypoint.py grpo \
+    --config-file template/grpo_fsdp.yaml \
+    runtime_args.exp_name=${exp_name} \
+    runtime_args.rollout_backend=sglang \
+    runtime_args.model_type=vlm \
+    runtime_args.data_path=${CHATLEARN}/dataset/geo3k/train.parquet \
+    runtime_args.eval_data_path=${CHATLEARN}/dataset/geo3k/test.parquet \
+    runtime_args.output_dir=${CHATLEARN}/output/${exp_name} \
+    runtime_args.num_episode=200 \
+    runtime_args.sample_per_episode=512 \
+    runtime_args.train_global_batch_size=512 \
+    runtime_args.train_micro_batch_size=8 \
+    runtime_args.save_episode_interval=5 \
+    runtime_args.eval_episode_interval=5 \
+    runtime_args.enable_eval_before_training=False \
+    runtime_args.log_args_dict.enable_wandb=False \
+    runtime_args.log_args_dict.wandb_project=your_wandb_project \
+    models.policy_trainer.num_gpu=${num_device} \
+    models.policy_trainer.packing=True \
+    models.policy_trainer.meta_init=True \
+    models.policy_trainer.groupgemm=True \
+    models.policy_trainer.generation_batch_size=64 \
+    models.policy_trainer.ulysses_sequence_parallel_size=1 \
+    models.policy_trainer.load=${CHATLEARN}/pretrained_models/Qwen3-VL-30B-A3B-Instruct/ \
+    models.policy_trainer.optimizer.lr=1e-6 \
+    models.policy_trainer.pos_clip_ratio=0.2 \
+    models.policy_trainer.neg_clip_ratio=0.2 \
+    models.policy_trainer.kl_coef=0.01 \
+    models.ref_policy.generation_batch_size=64 \
+    models.policy.generation_batch_size=64 \
+    models.policy.enforce_eager=False \
+    models.policy.tensor_model_parallel_size=1 \
+    models.policy.max_prompt_tokens_length=1024 \
+    models.policy.max_response_tokens_length=2048 \
+    models.policy.num_inference_per_prompt=4 \
+    models.policy.gpu_memory_utilization=0.85 \
+    models.policy.enable_thinking=False \
+    models.reward.generation_batch_size=256 \
+    2>&1 | tee log_${exp_name}.log ; exit ${PIPESTATUS[0]}
diff --git a/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_8b_grpo.sh b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_8b_grpo.sh
new file mode 100644
index 00000000..3bdc63cb
--- /dev/null
+++ b/scripts/fsdp_sglang/train_fsdp_sglang_qwen3_vl_8b_grpo.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# Tested on 8xH20-3e with 140G VRAM
+set -x
+
+export CHATLEARN=$(pwd)
+export PYTHONPATH=${CHATLEARN}:${PYTHONPATH}
+source scripts/base_env.sh
+export RAY_DEDUP_LOGS=1
+export exp_name=qwen3-vl-grpo-8b-sglang
+
+python chatlearn/entrypoint.py grpo \
+    --config-file template/grpo_fsdp.yaml \
+    runtime_args.exp_name=${exp_name} \
+    runtime_args.rollout_backend=sglang \
+    runtime_args.model_type=vlm \
+    runtime_args.data_path=${CHATLEARN}/dataset/geo3k/train.parquet \
+    runtime_args.eval_data_path=${CHATLEARN}/dataset/geo3k/test.parquet \
+    runtime_args.output_dir=${CHATLEARN}/output/${exp_name} \
+    runtime_args.num_episode=200 \
+    runtime_args.sample_per_episode=512 \
+    runtime_args.train_global_batch_size=512 \
+    runtime_args.train_micro_batch_size=8 \
+    runtime_args.save_episode_interval=5 \
+    runtime_args.eval_episode_interval=5 \
+    runtime_args.enable_eval_before_training=False \
+    runtime_args.log_args_dict.enable_wandb=False \
+    runtime_args.log_args_dict.wandb_project=your_wandb_project \
+    models.policy_trainer.num_gpu=${num_device} \
+    models.policy_trainer.packing=True \
+    models.policy_trainer.meta_init=False \
+    models.policy_trainer.groupgemm=False \
+    models.policy_trainer.generation_batch_size=64 \
+    models.policy_trainer.ulysses_sequence_parallel_size=1 \
+    models.policy_trainer.load=${CHATLEARN}/pretrained_models/Qwen3-VL-8B-Instruct/ \
+    models.policy_trainer.optimizer.lr=1e-6 \
+    models.policy_trainer.pos_clip_ratio=0.2 \
+    models.policy_trainer.neg_clip_ratio=0.2 \
+    models.policy_trainer.kl_coef=0.01 \
+    models.ref_policy.generation_batch_size=64 \
+    models.policy.generation_batch_size=64 \
+    models.policy.enforce_eager=False \
+    models.policy.tensor_model_parallel_size=1 \
+    models.policy.max_prompt_tokens_length=1024 \
+    models.policy.max_response_tokens_length=2048 \
+    models.policy.num_inference_per_prompt=4 \
+    models.policy.gpu_memory_utilization=0.85 \
+    models.policy.enable_thinking=False \
+    models.reward.generation_batch_size=256 \
+    2>&1 | tee log_${exp_name}.log ; exit ${PIPESTATUS[0]}
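
The sglang import shim above appears three times in this patch (once in fsdp_module.py, twice in sglang_module.py): sglang 0.5.3 moved `patch_torch` from `sglang.srt.patch_torch` to `sglang.srt.utils.patch_torch`. A minimal sketch of how the gate could be factored into one helper; `resolve_torch_reductions_patch` is a hypothetical name, not part of this patch:

```python
# Hypothetical consolidation of the version-gated import used in this patch.
# sglang >= 0.5.3 exposes the patch under sglang.srt.utils.patch_torch; older
# releases expose it at sglang.srt.patch_torch (per the branches added above).
from packaging.version import Version as PkgVersion


def resolve_torch_reductions_patch():
    import sglang  # imported lazily, as in the patched modules
    if PkgVersion(sglang.__version__) >= PkgVersion('0.5.3'):
        from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
    else:
        from sglang.srt.patch_torch import monkey_patch_torch_reductions
    return monkey_patch_torch_reductions


# usage at each call site:
#     resolve_torch_reductions_patch()()
```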
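The transformers gate added in policy_trainer.py (`torch.cat([position_ids[0:1], position_ids], dim=0)`) prepends a text row to the three mrope rows, per the `# add text position_ids for vl` comment in the diff. A standalone illustration of the resulting shape, on dummy data and assuming only that transformers >= 4.55 expects the extra leading text row for VL models:

```python
import torch
from packaging.version import Version as PkgVersion
import transformers

# Dummy mrope position ids: three rows (t, h, w) for a sequence of length 6.
position_ids = torch.arange(18).reshape(3, 6)

if PkgVersion(transformers.__version__) >= PkgVersion('4.55.0'):
    # Prepend a copy of the first row as the text row, giving shape (4, 6),
    # mirroring the concatenation the patched trainer performs above.
    position_ids = torch.cat([position_ids[0:1], position_ids], dim=0)

print(position_ids.shape)  # torch.Size([4, 6]) on transformers >= 4.55.0
```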