Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7298039
support trace_agg_mode
JiahangXu Oct 9, 2025
fb08c5f
remove breakpoint and fix corner case
JiahangXu Oct 11, 2025
47f5186
Merge branch 'main' into dev/support_mask
JiahangXu Oct 11, 2025
dfb6323
reformatted daemon
JiahangXu Oct 11, 2025
0d2dece
Merge branch 'dev/support_mask' of github.com:microsoft/agent-lightni…
JiahangXu Oct 11, 2025
76fedd6
add fuzzy_startswith to support special_token_tolerance and string_to…
JiahangXu Oct 27, 2025
ee253db
Merge branch 'main' into dev/support_mask
JiahangXu Nov 4, 2025
a29f5f4
refactor to trace_aggregator
JiahangXu Nov 5, 2025
a77075e
Merge branch 'main' into dev/support_mask
JiahangXu Nov 5, 2025
816c8ed
fix typo
JiahangXu Nov 5, 2025
20ff8ba
add logs, fix mask mapping
JiahangXu Nov 5, 2025
d93c2cc
fix typo
JiahangXu Nov 5, 2025
672f037
fix pylint error
JiahangXu Nov 5, 2025
364d539
fix pylint error
JiahangXu Nov 5, 2025
c8a8b93
Update Search-R1 Example to v0.2.x
SiyunZhao Nov 5, 2025
db9d280
delete redundant script
SiyunZhao Nov 5, 2025
5d551ad
add response id error log, convert to gen response
JiahangXu Nov 10, 2025
f60a49b
fix path
SiyunZhao Nov 13, 2025
5e35898
delete redundant parameter
SiyunZhao Nov 14, 2025
77de9e9
Merge branch 'dev/search_r1_v02' into dev/support_mask
JiahangXu Nov 18, 2025
5c3b9c6
stage debug scripts
JiahangXu Nov 26, 2025
7b57eed
Merge branch 'main' into dev/search_r1_v02
JiahangXu Nov 28, 2025
11e8589
update logger
JiahangXu Nov 28, 2025
161961d
stage test scripts
JiahangXu Nov 28, 2025
283eb89
update daemon timeout
JiahangXu Nov 28, 2025
bcbc95f
Merge branch 'dev/search_r1_v02' into dev/support_mask
JiahangXu Nov 29, 2025
0837a9d
update unmerged logs
JiahangXu Nov 29, 2025
b5afe14
update trajectory scripts
JiahangXu Nov 29, 2025
6bba3dc
Merge branch 'main' into dev/support_mask
JiahangXu Dec 2, 2025
f8a47d9
update mismatch logs, update scripts
JiahangXu Dec 2, 2025
67dad08
update logs code
JiahangXu Dec 3, 2025
e1dbc3c
update running scripts
JiahangXu Dec 3, 2025
e5be217
update external store
JiahangXu Dec 3, 2025
d458b26
update external store (fix bug)
JiahangXu Dec 4, 2025
951b168
update trajectory-strict and trajectory-tolerant
JiahangXu Dec 4, 2025
5a5d54f
update training scripts with new args
JiahangXu Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions agentlightning/verl/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ actor_rollout_ref:
custom_async_server:
path: pkg://agentlightning.verl.async_server
name: PatchedvLLMServer
trace_aggregator:
mode: transition # one of: transition, trajectory-strict, trajectory-tolerant
special_token_tolerance: 10 # only supported in trajectory-tolerant mode; suggested value: n_turns
string_tolerance: 20 # only supported in trajectory-tolerant mode; suggested value: n_turns * 2
trajectory_max_length: 8192 # supported in both trajectory modes; suggested value: n_turns * (max_response_length + max_prompt_length)
342 changes: 311 additions & 31 deletions agentlightning/verl/daemon.py

Large diffs are not rendered by default.

25 changes: 23 additions & 2 deletions agentlightning/verl/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True, suffix: str
return metrics


def log_step_for_mismatch_detail(step: int) -> None:
    """Append a step-separator line to each mismatch debug log.

    Creates the ``mismatch_log`` directory (relative to the current working
    directory) if it does not exist, then appends one header line of the form
    ``---------- Step <step>----------`` to every mismatch log file so that
    entries written afterwards can be attributed to a training step.

    Args:
        step: The step number to embed in the separator line.
    """
    import os

    log_dir = "mismatch_log"
    # One file per mismatch category; add new categories here.
    log_names = (
        "template_mismatch.log",
        "retoken_mismatch.log",
        "others_mismatch.log",
        "response_ids_num_mismatch.log",
        "bad_case_unexpected.log",
    )
    header = "-" * 10 + f" Step {step}" + "-" * 10

    os.makedirs(log_dir, exist_ok=True)
    for log_name in log_names:
        # Append-only: the files are never read here, so plain "a" suffices.
        # An explicit encoding keeps the output independent of the locale.
        with open(f"{log_dir}/{log_name}", "a", encoding="utf-8") as f:
            print(header, file=f)


class AgentLightningTrainer(RayPPOTrainer):
"""
Specialized PPO trainer for agent-based reinforcement learning.
Expand Down Expand Up @@ -217,9 +232,12 @@ def _train_step(self, batch_dict: dict) -> dict:
gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
)
self.agent_mode_daemon.run_until_all_finished()
with _timer("gen_postprocess", timing_raw):
batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
max_prompt_length=self.config.data.max_prompt_length,
max_response_length=self.config.data.max_response_length,
max_response_length=self.config.actor_rollout_ref.rollout.trace_aggregator.trajectory_max_length \
if self.config.actor_rollout_ref.rollout.trace_aggregator.mode.startswith("trajectory") else \
self.config.data.max_response_length,
device=gen_batch.batch["fake_ids"].device,
)
metrics.update(agent_metrics)
Expand All @@ -245,7 +263,8 @@ def _train_step(self, batch_dict: dict) -> dict:
# uid is used for algorithm like GRPO, should be aligned to data id
batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]

batch.batch["response_mask"] = compute_response_mask(batch)
if "response_mask" not in batch.batch:
batch.batch["response_mask"] = compute_response_mask(batch)

# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
Expand Down Expand Up @@ -427,6 +446,7 @@ def fit(self):
store=self.store,
llm_proxy=self.llm_proxy,
adapter=self.adapter,
trace_aggregator=self.config.actor_rollout_ref.rollout.trace_aggregator,
)
self.agent_mode_daemon.start()

Expand All @@ -449,6 +469,7 @@ def fit(self):

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
log_step_for_mismatch_detail(self.global_steps) # log data, only for debug testing
metrics = {}
timing_raw = {}
is_last_step = self.global_steps >= self.total_training_steps
Expand Down
Loading
Loading