21 changes: 20 additions & 1 deletion chatlearn/algorithm/grpo_utils/policy_trainer.py
@@ -22,6 +22,8 @@
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn.bert_padding import pad_input
from packaging.version import Version as PkgVersion
import transformers

from chatlearn import FSDPModule
from chatlearn.utils import to_device
@@ -37,7 +39,6 @@
split_and_unpadding,
unpad_input)


class PolicyTrainer(FSDPModule):
"""policy trainer"""
def setup(self):
@@ -120,6 +121,8 @@ def preprocess_data_list(self, data_list: List[Dict[str, Any]], training: bool):
if self.runtime_args.model_type == 'vlm':
# vl
position_ids = position_ids.permute(0, 2, 1).cpu()
if PkgVersion(transformers.__version__)>=PkgVersion('4.55.0'):
position_ids = torch.cat([position_ids[0:1], position_ids], dim=0) # add text position_ids for vl
else:
position_ids = position_ids.permute(1, 0).cpu() # For compatibility with transformers

@@ -168,6 +171,19 @@ def preprocess_data_list(self, data_list: List[Dict[str, Any]], training: bool):
data_after_process.append(data_obj)
return response_token_length_total, data_after_process

def compute_vl_position_ids(self, data_list: List[Dict[str, Any]]):
input_ids_key = 'input_ids' if 'input_ids' in data_list[0] else 'prompt_token_ids'

for data_b in data_list:
position_ids, _ = self.model.model.get_rope_index(
input_ids=torch.tensor(data_b[input_ids_key]).unsqueeze(0),
image_grid_thw=data_b["image_grid_thw"],
attention_mask=torch.tensor(data_b['attention_mask']).unsqueeze(0)
)
data_b['position_ids'] = position_ids.squeeze().tolist()

return data_list

@monitor_error()
@compute_decorator(trainable=True, rollout=False)
@timeit()
@@ -298,6 +314,9 @@ def train_step(self, data_list: List[Dict[str, Any]], **kwargs): # pylint: disab
@compute_decorator(trainable=False, rollout=False)
@timeit()
def forward_step(self, data: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: # pylint: disable=unused-argument,arguments-differ
if self.runtime_args.model_type == 'vlm':
data = self.compute_vl_position_ids(data)

_, data_list = self.preprocess_data_list(data_list=data, training=False)
tag = "old_logprobs" if self.trainable else "ref_logprobs"
# Logprobs holder
17 changes: 3 additions & 14 deletions chatlearn/data/vl_prompt_dataset.py
@@ -6,9 +6,6 @@
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

from chatlearn.models.patches.transformers.qwen2_5_vl_patch import get_rope_index


class PromptPipeline(Dataset):
"""
Input data_list: List[Dict])
@@ -41,6 +38,7 @@ class PromptPipeline(Dataset):
"mm_processor_kwargs": {'fps':[]}, # used for video useless now
"pixel_values": Tensor, # [grid_num, pixel_num]
"image_grid_thw": Tensor, # [1,3] 3 means t,h,w
"attention_mask": List, used for compute position_ids
}
"""
def __init__(
@@ -98,15 +96,6 @@ def __init__(

# text only input_ids for vllm
raw_input_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
# get position_ids used for sequence packing
position_ids, _ = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=model_inputs.get("image_grid_thw"),
video_grid_thw=model_inputs.get("video_grid_thw"),
second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
attention_mask=attention_mask,
)

# for vl model, raw_input_ids is only text input_ids for vllm inference
# input_ids is used for model forward_step and sglang inference (with image pad)
@@ -116,11 +105,11 @@
"input_ids": input_ids[0].tolist(),
"prompt_token_length": len(input_ids[0].tolist()),
"prompt": raw_prompt,
"position_ids": position_ids.squeeze().tolist(),
"multi_modal_data": multi_modal_data,
"mm_processor_kwargs": mm_processor_kwargs,
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw
"image_grid_thw": image_grid_thw,
"attention_mask": attention_mask[0].tolist(),
})
if len(input_ids[0]) > self.max_prompt:
self.max_prompt = len(input_ids[0])
5 changes: 3 additions & 2 deletions chatlearn/models/agent/agent_module.py
@@ -89,7 +89,8 @@ def postprocess_func(
prompt_token_ids = output.prompt_ids
pixel_values = output.pixel_values
image_grid_thw = output.image_grid_thw
position_ids = output.position_ids
attention_mask = output.attention_mask

response_token_length = len(output.all_token_ids) - len(output.prompt_ids)
prompt_token_length = len(output.prompt_ids)
str_outputs = output.str_output
@@ -108,7 +109,7 @@
# multimodel related
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
"position_ids": position_ids
"attention_mask": attentiion_mask
}
)
data_output.append(input_data)
20 changes: 4 additions & 16 deletions chatlearn/models/agent/base_agent_graph.py
@@ -26,9 +26,6 @@
find_last_ai_index,
find_first_ai_index)
from chatlearn.models.sglang_module import AsyncEngine
from chatlearn.models.patches.transformers.qwen2_5_vl_patch import get_rope_index


def find_first_zero_group_end(lst):
for i, x in enumerate(lst):
if x != 0:
@@ -55,7 +52,7 @@ class AgentGraphOutput(BaseModel):
# multimodel related item
pixel_values: Any = None
image_grid_thw: Any = None
position_ids: Any = None
attention_mask: Any = None
# Extra fields for dynamic addition.
extra_fields: dict[str, Any] = {}

@@ -108,21 +105,12 @@ def convert_agent_graph_output(self, messages: Dict) -> AgentGraphOutput:
num_turns = last_ai_message_idx + 1
str_output = self.tokenizer.decode(all_token_ids[prompt_end_idx + 1 :])

pixel_values, image_grid_thw, position_ids = None, None, None
pixel_values, image_grid_thw, attention_mask = None, None, None
multimodel_batch_feature = messages[first_ai_message_idx].response_metadata.get("multimodel_batch_feature", None)
if multimodel_batch_feature:
pixel_values = multimodel_batch_feature.get("pixel_values")
image_grid_thw = multimodel_batch_feature.get("image_grid_thw")
# need to get position ids used in sequence packing
position_ids, _ = get_rope_index(
self.processor,
input_ids=multimodel_batch_feature.get("input_ids"),
image_grid_thw=multimodel_batch_feature.get("image_grid_thw"),
video_grid_thw=multimodel_batch_feature.get("video_grid_thw"),
second_per_grid_ts=multimodel_batch_feature.get("second_per_grid_ts"),
attention_mask=multimodel_batch_feature.get("attention_mask"),
)
position_ids = position_ids.squeeze().tolist()
attention_mask = multimodel_batch_feature.get("attention_mask")[0].tolist()

return AgentGraphOutput(
str_output=str_output,
@@ -132,5 +120,5 @@ def convert_agent_graph_output(self, messages: Dict) -> AgentGraphOutput:
num_turns=num_turns,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
position_ids=position_ids
attention_mask=attention_mask
)
44 changes: 29 additions & 15 deletions chatlearn/models/fsdp_module.py
@@ -23,7 +23,6 @@

import numpy as np
from packaging.version import Version as PkgVersion
import transformers
import torch
from torch import Tensor
import torch.distributed as dist
@@ -258,10 +257,9 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo
attn_implementation="flash_attention_2",
trust_remote_code=self.module_args.trust_remote_code
)
if PkgVersion(transformers.__version__)==PkgVersion('4.51.3'):
# vl patch needed for transformers 4.51.3
from chatlearn.models.patches.monkey_patch import apply_qwenvl
apply_qwenvl(model)

from chatlearn.models.patches.monkey_patch import apply_qwenvl
apply_qwenvl(model)

assert self.sp_size == 1, "VL model only support sp_size=1"
else:
Expand All @@ -273,14 +271,26 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo
)
else:
model_config = AutoConfig.from_pretrained(model_path)
assert "Qwen2_5_VLForConditionalGeneration" not in model_config.architectures, "VL model not support meta init"
with init_on_device('meta', include_buffers=False):
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
trust_remote_code=self.module_args.trust_remote_code
)

if self.runtime_args.model_type == 'vlm':
with init_on_device('meta', include_buffers=False):
model = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path=model_path,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
trust_remote_code=self.module_args.trust_remote_code
)

from chatlearn.models.patches.monkey_patch import apply_qwenvl
apply_qwenvl(model)
else:
with init_on_device('meta', include_buffers=False):
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
trust_remote_code=self.module_args.trust_remote_code
)
dist.barrier()
return model
@property
@@ -507,9 +517,13 @@ def get_weight_ipc_handles_by_name(self, block_name: List[str]):
if rollout_engine == "sglang":
# lazy import sglang
from sglang.srt.utils import MultiprocessingSerializer
from sglang.srt.patch_torch import monkey_patch_torch_reductions

import sglang
if PkgVersion(sglang.__version__)>=PkgVersion('0.5.3'):
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
else:
from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_reductions()

flattened_tensor, metadatas = self.convert_block2flattened_bucket(
block_parameter
)
14 changes: 12 additions & 2 deletions chatlearn/models/patches/monkey_patch.py
@@ -13,6 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Apply patches for different model architectures"""
from packaging.version import Version as PkgVersion
import transformers

def apply_sp_monkey_patch(model_config):
print(f"applying sequence parallel patches for {model_config.architectures}")
if model_config.architectures[0] == "Qwen2ForCausalLM":
@@ -42,8 +45,15 @@ def apply_group_gemm(model):
def apply_qwenvl(model):
print(f"applying qwenvl patches for {model.config.architectures[0]}")
if model.config.architectures[0] == "Qwen2_5_VLForConditionalGeneration":
from chatlearn.models.patches.transformers.qwen2_5_vl_patch import apply_qwenvl_patch \
if PkgVersion(transformers.__version__)==PkgVersion('4.51.3'):
# vl2.5 patch needed for transformers 4.51.3
from chatlearn.models.patches.transformers.qwen2_5_vl_patch import apply_qwenvl_patch \
# pylint: disable=import-outside-toplevel
apply_qwenvl_patch()
elif model.config.architectures[0] in ["Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration"]:
assert PkgVersion(transformers.__version__)>=PkgVersion('4.57.0'), "qwen3vl requires transformers >= 4.57.0"
from chatlearn.models.patches.transformers.qwen3_vl_patch import apply_qwen3vl_patch \
# pylint: disable=import-outside-toplevel
apply_qwenvl_patch()
apply_qwen3vl_patch()
else:
raise ValueError(f"Unsupported model architecture: {model.config.architectures} for qwenvl patch")
33 changes: 33 additions & 0 deletions chatlearn/models/patches/transformers/qwen3_vl_patch.py
@@ -0,0 +1,33 @@
"""patches for qwen3 vl model"""
from typing import Optional
import torch

def Qwen3VLBlock_patched_forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
# =========================================================================
# force a dtype cast for qwen3_vl; otherwise backward hits a dtype mismatch error
hidden_states = hidden_states.to(self.norm1.weight.dtype)
# =========================================================================

hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states

def apply_qwen3vl_patch():
# pylint: disable=import-outside-toplevel
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionBlock
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionBlock
Qwen3VLVisionBlock.forward = Qwen3VLBlock_patched_forward
Qwen3VLMoeVisionBlock.forward = Qwen3VLBlock_patched_forward
12 changes: 10 additions & 2 deletions chatlearn/models/sglang_module.py
@@ -570,8 +570,12 @@ def parameter_sync(self):
def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
"""Used for Mcore2SGLang Parameter Sync
"""
from sglang.srt.patch_torch import monkey_patch_torch_reductions
if PkgVersion(sglang.__version__)>=PkgVersion('0.5.3'):
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
else:
from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_reductions()

param_id_to_update = set()
for bucket in buckets:
if bucket is None:
@@ -766,8 +770,12 @@ async def update_weights_from_ipc_handles(self, reduce_data):

@torch.no_grad()
async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
from sglang.srt.patch_torch import monkey_patch_torch_reductions
if PkgVersion(sglang.__version__)>=PkgVersion('0.5.3'):
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
else:
from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_reductions()

param_id_to_update = set()
for bucket in buckets:
if bucket is None:
15 changes: 15 additions & 0 deletions docker/torch/Dockerfile.torch2.8.0.sglang053
@@ -0,0 +1,15 @@
FROM nvcr.io/nvidia/pytorch:24.12-py3
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
RUN pip install --no-cache-dir "sglang[all]==0.5.3.post1"
RUN pip install --no-cache-dir transformers==4.57.0
RUN pip install --no-cache-dir langgraph==0.6.6
RUN pip install --no-cache-dir ray[default]==2.46.0
RUN pip install --no-cache-dir accelerate==1.10.0
RUN pip install --no-cache-dir wandb==0.19.3
RUN pip install --no-cache-dir hydra-core==1.3.2
RUN pip install --no-cache-dir grpcio==1.70.0 nvidia-modelopt==0.27.0 nvidia-modelopt-core==0.27.0 datasets==3.6.0 deepspeed==0.16.7
RUN pip install --no-cache-dir mathruler==0.1.0 pylatexenc==2.10 qwen-vl-utils==0.0.14
RUN pip uninstall -y flash_attn && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/flash-attention/torch2.8.0-cu12x/flash_attn-2.7.4.post1-cp312-cp312-linux_x86_64.whl
RUN pip uninstall -y transformer_engine && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.8.0-cuda12x/transformer_engine-2.3.0%2B5de3e148-cp312-cp312-linux_x86_64.whl
RUN pip uninstall -y apex && pip install --no-cache-dir https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.8.0-cuda12x/apex-0.1-cp312-cp312-linux_x86_64.whl
12 changes: 10 additions & 2 deletions docs/en/installation.md
@@ -14,11 +14,19 @@ dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2

### SGLang

You can prepare the image by referring to [Dockerfile.torch2.8.0.sglang052](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang052). Alternatively, you can directly pull and use the following image:
You can prepare the image by referring to [Dockerfile.torch2.8.0.sglang053](https://github.com/alibaba/ChatLearn/blob/main/docker/torch/Dockerfile.torch2.8.0.sglang053). Alternatively, you can directly pull and use the following image:

```bash
dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312
dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312
```
### Image History

| Image URL | Pkg Version | Model List |
| ------------------------------------------------------------ | ----------------------------------------- | ------------------------------------------ |
| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312 | sglang 0.5.3.post1<br>transformers 4.57.0 | Qwen3-VL<br>Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.2-ubuntu24.04-cuda12.6-py312 | sglang 0.5.2<br>transformers 4.56.1 | Qwen2.5-VL<br>Qwen3<br>Qwen2.5 |
| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-te2.7-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformer_engine 2.7 | Moonlight<br>Deepseek-r1 |
| dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 | vllm 0.8.5<br>transformers 4.51.3 | Qwen2.5-VL<br>Qwen3<br/>Qwen2.5 |
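
For example, pulling the sglang 0.5.3 image from the table above might look like the following (a minimal sketch; it assumes the Aliyun registry is reachable from your environment):

```bash
# Pull the torch2.8.0 + sglang 0.5.3 ChatLearn image listed above
docker pull dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.8.0-sglang0.5.3-ubuntu24.04-cuda12.6-py312
```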

## 2. Code Preparation
