diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py
index 9e8d3bd3..0483f2d3 100644
--- a/roll/distributed/scheduler/generate_scheduler.py
+++ b/roll/distributed/scheduler/generate_scheduler.py
@@ -171,7 +171,7 @@ def generate(self, data: DataProto, actor_cluster: Union[Any, Cluster], pipeline
def get_available_dp_rank(self):
while True:
- # 负载均衡逻辑,期望各dp 正在处理的条数基本接近
+ # Load balancing: keep the number of in-flight items roughly equal across dp ranks
sorted_ranks = sorted(
self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank)
)
@@ -205,26 +205,26 @@ def generate_opt_level_1(self, data: DataProto):
)
self.cluster.start_server(data=DataProto(meta_info=data.meta_info), blocking=True)
- # 分发数据至收到target rollout 完成
- # 无限循环,把所有的response发送给dp worker
+ # Distribute data until the target rollout count is reached
+ # Infinite loop: keep dispatching requests to the dp workers
send_request_count = 0
request_refs = []
data_index_counter = itertools.count()
last_alive_check = time.time()
while not self.is_completed:
- # 探测dp worker是否存活,dp worker的server thread可能由于异常退出,造成hang
+ # Probe whether the dp workers are alive; a dp worker's server thread may exit on an exception and cause a hang
current_time = time.time()
if current_time - last_alive_check >= self.alive_check_interval:
self.cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, data=DataProto())
last_alive_check = current_time
if send_request_count < data.batch.batch_size[0]:
- # 取一个可以发送request的dp worker
+ # Pick a dp worker that can accept a new request
dp_rank = next(self.get_available_dp_rank())
- # 还有数据需要发送, 取需要发送的数据
- # request_id 全局递增,否则vllm/sglang scheduler状态不对
+ # There is still data to send; fetch the next item to dispatch
+ # request_id must increase globally, otherwise the vllm/sglang scheduler state becomes inconsistent
request_id = next(self.request_counter)
data_index = next(data_index_counter)
request_data = collate_fn([self.data[data_index]])
@@ -235,7 +235,7 @@ def generate_opt_level_1(self, data: DataProto):
].item()
self.request_id_2_dp_rank[request_data.meta_info["request_id"]] = dp_rank
self.prompt_id_2_request_ids[prompt_id].add(request_data.meta_info["request_id"])
- # 需要注意上面的调用顺序, report_response中会更新request_id索引dp_rank,所以这里需要最后add request_id
+ # Note the call order above: report_response updates the request_id -> dp_rank index, so the request_id must be added last
request_data.meta_info["response_callback_fn"] = self.response_callback_fn
request_data.meta_info["generation_config"] = data.meta_info["generation_config"]
request_refs.append(
@@ -394,7 +394,7 @@ def set_scheduler(
state: Dict[str, Any] = None,
):
"""
- GenerateScheduler可以由多个实例,不再局限于单例
+ GenerateScheduler may have multiple instances; it is no longer restricted to a singleton
"""
self.actor_cluster = actor_cluster
self.reward_clusters = reward_clusters
@@ -459,9 +459,9 @@ def reset_status(self):
def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
"""
- 从dataset里,按给定策略sample batch
- 1. 常规无过滤
- 2. 动态过滤
+ Sample a batch from the dataset with the given strategy:
+ 1. Regular (no filtering)
+ 2. Dynamic filtering
"""
self.batch_size = batch_size
self.reset_status()
@@ -522,7 +522,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
f"used queries: {query_use_count} query_filter_count: {self.query_filter_count} "
f"response_filter_count: {self.response_filter_count}"
)
- # TODO: 这里 len(collect_data) > rollout_batch_size, 可以尝试动态扩大batch_size
+ # TODO: here len(collect_data) > rollout_batch_size; we could try expanding batch_size dynamically
batch = DataProto.concat(collect_data[: self.batch_size * num_return_sequences])
batch.meta_info["metrics"] = {
f"scheduler/query_filter_count": self.query_filter_count,
@@ -531,7 +531,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
f"scheduler/query_use_count": query_use_count,
}
- # 统计全部response metrics
+ # Aggregate metrics over all responses
metrics = {}
for domain, response_batches in self.response_cache.items():
response_batch = DataProto.concat(response_batches[:])
@@ -548,8 +548,8 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto:
@ray.method(concurrency_group="multi_thread")
def report_response(self, data: DataProto):
"""
- 这里需要考虑多线程数据访问
- data 返回可能有多条的
+ Multi-threaded data access must be handled here
+ The returned data may contain multiple entries
"""
try:
request_id = data.meta_info["request_id"]
@@ -570,15 +570,15 @@ def report_response(self, data: DataProto):
return
# call reward
- # reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意...
- # 多域的时候,llm as judge, 需要单独为reward worker分配gpu
+ # The reward worker must support computing rewards for a single sample; with dynamic sampling, take care when rewards have to be computed in batches...
+ # In the multi-domain case (llm as judge), the reward worker needs its own gpu allocation
rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch))
batch.union(rewards)
response_buffers: List[DataProto] = []
batch_expanded = [batch[[idx]] for idx in range(output_count)]
- # response_filter, 不太需要response filter
+ # response_filter; a response filter is not really needed here
for batch_item in batch_expanded:
if self.response_filter_fn(batch_item, self.pipeline_config):
response_buffers.append(batch_item)
@@ -706,7 +706,7 @@ def expand_requests(self, data: DataProto):
return target_requests
def check_worker_alive(self, cluster):
- # 探测dp worker是否存活,dp worker的server thread可能由于异常退出,造成hang
+ # Probe whether the dp workers are alive; a dp worker's server thread may exit on an exception and cause a hang
current_time = time.time()
if current_time - self.last_alive_check >= self.alive_check_interval:
cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, data=DataProto())
@@ -727,7 +727,7 @@ def check_send_new_request(self) -> bool:
def get_available_dp_rank(self):
while True:
- # 负载均衡逻辑,期望各dp 正在处理的条数基本接近
+ # Load balancing: keep the number of in-flight items roughly equal across dp ranks
sorted_ranks = sorted(
self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank)
)
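As a side note on the load-balancing comments above, here is a minimal stand-alone sketch of the rank-selection idea they describe, assuming a dict like load_balance_coordinator that maps a dp rank to its in-flight request count; the helper itself is hypothetical and not part of the patch.

from typing import Dict

def pick_least_loaded_rank(load_balance_coordinator: Dict[int, int]) -> int:
    # Sort by (in-flight count, rank) so the least busy rank wins and ties are deterministic.
    return sorted(
        load_balance_coordinator.keys(),
        key=lambda rank: (load_balance_coordinator[rank], rank),
    )[0]

# Example: rank 2 has the fewest in-flight requests, so it is chosen.
assert pick_least_loaded_rank({0: 3, 1: 2, 2: 1}) == 2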
diff --git a/roll/distributed/scheduler/reward_scheduler.py b/roll/distributed/scheduler/reward_scheduler.py
index 120bb78d..8703ecd9 100644
--- a/roll/distributed/scheduler/reward_scheduler.py
+++ b/roll/distributed/scheduler/reward_scheduler.py
@@ -15,14 +15,14 @@
@ray.remote
class RewardScheduler:
"""
- reward 服务化和generate不同, request接口:
- reward scheduler需要解决的是不同域的sample的reward计算问题, 不需要实现request粒度的接口;
- 并且reward计算和vllm不同,vllm可以continue batch,所以可以动态add request, reward不行,
- 直接rpc调用reward_cluster.compute_rewards即可(使用rpc方式调用,可以增加reward的数量,增大并发处理能力)
+ Serving rewards differs from serving generation (request interface):
+ the reward scheduler only needs to route reward computation for samples from different domains; it does not need a request-level interface.
+ Reward computation also differs from vllm: vllm does continuous batching and can add requests dynamically, but reward computation cannot,
+ so we call reward_cluster.compute_rewards directly via rpc (the rpc path lets us scale out reward workers and increase concurrent processing capacity)
- reward scheduler需要解决的问题:
- 按domain路由reward
- dp dispatch 均分/不足dp_size 的限制
+ Problems the reward scheduler needs to solve:
+ routing reward requests by domain
+ dp dispatch constraints: even splitting, and batches smaller than dp_size
"""
def __init__(self):
@@ -32,13 +32,13 @@ def __init__(self):
def compute_rewards(self, data: DataProto, reward_clusters: Dict[str, Any], pipeline_config) -> DataProto:
"""
- 保序返回rewards
+ Return rewards preserving the input order
"""
self.pipeline_config = pipeline_config
self.reward_clusters = reward_clusters
data.batch["prompt_id"] = torch.arange(data.batch.batch_size[0], device=data.batch.device)
- # 按domain group by data
+ # Group data by domain
grouped_data: Dict[str, DataProto] = data.group_by("domain")
domain_rewards_refs: Dict[str, List[ray.ObjectRef]] = defaultdict(list)
@@ -51,8 +51,8 @@ def compute_rewards(self, data: DataProto, reward_clusters: Dict[str, Any], pipe
rewards_list: List[DataProto] = []
for domain, domain_rewards_ref in domain_rewards_refs.items():
- # 各reward的输出schema要求一致
- # reward worker compute_rewards 接口返回结果保序
+ # All reward workers must emit the same output schema
+ # The reward worker's compute_rewards interface returns results in order
if domain not in grouped_data.keys():
continue
domain_rewards: DataProto = DataProto.materialize_concat(data_refs=domain_rewards_ref)
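To make the group-by-domain and order-preserving behavior documented above concrete, a minimal stand-alone sketch follows; the function name and plain-Python types are illustrative and do not mirror the DataProto API.

from collections import defaultdict
from typing import Callable, Dict, List, Tuple

def route_and_restore(
    samples: List[Tuple[str, str]],                    # (domain, response_text)
    reward_fns: Dict[str, Callable[[str], float]],
) -> List[float]:
    # Remember each sample's original position (the patch uses a prompt_id arange for this),
    # group by domain, score each group with its own reward fn, then restore the input order.
    grouped: Dict[str, List[Tuple[int, str]]] = defaultdict(list)
    for idx, (domain, text) in enumerate(samples):
        grouped[domain].append((idx, text))
    rewards = [0.0] * len(samples)
    for domain, items in grouped.items():
        fn = reward_fns[domain]
        for idx, text in items:
            rewards[idx] = fn(text)
    return rewards

# Example with two domains and trivial reward fns.
fns = {"math": lambda t: float("42" in t), "code": lambda t: float("def " in t)}
print(route_and_restore([("math", "the answer is 42"), ("code", "no function here")], fns))  # [1.0, 0.0]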
diff --git a/roll/models/model_providers.py b/roll/models/model_providers.py
index 2c147ed5..ed3ffe54 100644
--- a/roll/models/model_providers.py
+++ b/roll/models/model_providers.py
@@ -376,9 +376,9 @@ def default_reward_model_provider(
is_trainable: Optional[bool] = False,
):
"""
- model.forward 遵循TokenClassifierOutput 协议
+ model.forward follows TokenClassifierOutput protocol
class TokenClassifierOutput(ModelOutput):
- logits: torch.FloatTensor # 必须要有
+ logits: torch.FloatTensor # Required
loss: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
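A minimal, self-contained sketch of the TokenClassifierOutput contract referenced in the docstring above; TinyRewardHead is a hypothetical module, shown only to illustrate that forward must return an object exposing per-token logits.

from dataclasses import dataclass
from typing import Optional, Tuple
import torch

@dataclass
class TokenClassifierOutput:
    logits: torch.FloatTensor                   # required
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

class TinyRewardHead(torch.nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.score = torch.nn.Linear(hidden, 1)

    def forward(self, hidden_states: torch.Tensor) -> TokenClassifierOutput:
        # One scalar score per token, returned under `logits` as the provider expects.
        return TokenClassifierOutput(logits=self.score(hidden_states))

out = TinyRewardHead()(torch.randn(2, 8, 16))
print(out.logits.shape)  # torch.Size([2, 8, 1])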
diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py
index 10999450..e49eb140 100644
--- a/roll/pipeline/agentic/agentic_config.py
+++ b/roll/pipeline/agentic/agentic_config.py
@@ -45,7 +45,7 @@ class EnvManagerConfig(WorkerConfig):
def __post_init__(self):
"""
- 根据es config计算world_size
+ Calculate world_size based on es config
"""
self.world_size = self.env_groups * self.group_size
self.env_configs: Optional[Dict[int, Dict]] = None
@@ -266,7 +266,7 @@ def set_max_steps(self, max_steps: int):
self.critic.training_args.per_device_train_batch_size
* self.critic.training_args.gradient_accumulation_steps
)
- # 没有除dp_size,需要在分布式环境初始化后再除
+ # Not yet divided by dp_size; the division happens after the distributed environment is initialized
self.actor_train.training_args.max_steps = max_steps * (
self.rollout_batch_size
* self.actor_infer.generating_args.num_return_sequences
diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py
index 54ea5c4f..3b991fba 100644
--- a/roll/pipeline/agentic/agentic_pipeline.py
+++ b/roll/pipeline/agentic/agentic_pipeline.py
@@ -107,7 +107,7 @@ def __init__(self, pipeline_config: AgenticConfig):
@torch.no_grad()
def run(self):
- # 计算tokens per second 系统吞吐
+ # Track system throughput in tokens per second
tps_timer = _Timer(window_size=5)
for global_step in range(self.pipeline_config.max_steps):
@@ -191,8 +191,8 @@ def run(self):
metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {})))
metrics["time/old_log_probs_values"] = cal_old_logpb_timer.last
- # 要按group by处理reward
- # 可以tag(env_type)/traj_group_id(group)/batch(rollout_batch)... group_by计算reward/adv
+ # Rewards must be processed group by group
+ # Reward/adv can be computed grouped by tag(env_type)/traj_group_id(group)/batch(rollout_batch)...
batch.batch["prompt_id"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device)
with Timer(name="adv", logger=None) as timer:
grouping = self.pipeline_config.reward_normalization.grouping
@@ -228,7 +228,7 @@ def run(self):
batch = DataProto.concat(batch_list)
batch.reorder(indices=torch.argsort(batch.batch["prompt_id"]))
batch.pop("prompt_id")
- # advantage是全局batch计算,还是group内计算?
+ # Is advantage calculated globally across batch or within groups?
batch = compute_advantage(
data=batch,
gamma=self.pipeline_config.gamma,
@@ -314,8 +314,8 @@ def run(self):
def compute_data_metrics(batch):
- # token_level_scores 是reward model给每个token的打分,可能经过了norm/clip
- # score 为env的reward,raw value
+ # token_level_scores are the per-token scores from the reward model, possibly after norm/clip
+ # score is the raw environment reward
sequence_score = batch.batch["scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
diff --git a/roll/pipeline/agentic/environment_worker.py b/roll/pipeline/agentic/environment_worker.py
index 80493b35..6576eb11 100644
--- a/roll/pipeline/agentic/environment_worker.py
+++ b/roll/pipeline/agentic/environment_worker.py
@@ -51,7 +51,7 @@ def get_masks_and_scores(
):
"""
input_ids: shape (bsz, seq_len)
- all_scores: list[list[float], 存储每个env每轮的reward
+ all_scores: list[list[float]], stores reward for each env per turn
Get loss mask that only learns between <|im_start|>assistant and <|im_end|>. Currently only supports qwen.
NOTE: important! This assumes that the input_ids starts with system and then user & assistant in alternative ways
NOTE: important! input_ids is left pad
@@ -65,7 +65,7 @@ def get_masks_and_scores(
non_prompt_mask = turn_indicators > 2 # learns everything after system prompt + user prompts
# turn text: '<|im_start|>assistant\nRight<|im_end|>'
- # <|im_start|>assistant\n 应该mask掉才对,保留<|im_end|>
+ # <|im_start|>assistant\n should be masked, keep <|im_end|>
for idx, scores in enumerate(zip_longest(*all_scores, fillvalue=0)):
turn_indicator = idx * 2 + 3 # 0: pad. 1: system. 2+2n: user. 3+2n: assistant
turn_start_position = (input_ids == turn_start_token) & (turn_indicators == turn_indicator)
@@ -129,19 +129,19 @@ def left_pad_2_right(
class EnvironmentWorker(Worker):
"""
- 1. 一个EnvironmentWorker(进程)持有一个env实例: 执行env.reset, env.step, 管理rollout的状态
- group trajectory表达: group内的init state一致,依赖env_config 中的seed来控制, 一个group内env 对应episode的seed一致
- 不采用持有envs的原因是,envs需要管理一组env的交互,增加描述的复杂性
- 2. 持有infer_cluster ref, 用于async generate
- 3. run_rollout_loop, 持续rollout trajectory, 将done的trajectory回传到output_queue
-
- 承担EnvStateManager的history收集功能
- 一个group内的env reset进度应该一致
-
- TODO: env并行方式后续改成进程+线程并行:目的解决一个env占用一个进程对系统资源的开销
- - 一个EnvironmentWorker持有n个EnvStateManager
- - EnvStateManager管理一个env的rollout loop
- - EnvStateManager.run_rollout_loop,运行在n个线程里
+ 1. One EnvironmentWorker (a process) holds a single env instance: it runs env.reset and env.step and manages rollout state.
+ Group trajectory semantics: envs in a group share the same init state, controlled by the seed in env_config; the envs of a group use the same seed for a given episode.
+ We do not hold a set of envs per worker because managing the interaction of a group of envs would complicate the description.
+ 2. Holds an infer_cluster ref, used for async generate.
+ 3. run_rollout_loop keeps rolling out trajectories and pushes finished (done) trajectories to the output_queue.
+
+ Also takes over EnvStateManager's history-collection role.
+ Env reset progress within a group should stay in sync.
+
+ TODO: switch env parallelism to process+thread parallelism, to avoid the system overhead of one env per process:
+ - one EnvironmentWorker holds n EnvStateManagers
+ - an EnvStateManager manages the rollout loop of one env
+ - EnvStateManager.run_rollout_loop runs in n threads
TODO: GiGPO: https://arxiv.org/abs/2505.10978
"""
@@ -276,7 +276,7 @@ def generate(self, env_output: Dict):
lm_output: DataProto = ray.get(self.generate_scheduler.generate_one_request.remote(data=gen_batch))
if lm_output is not None:
- # 未被abort
+ # not aborted
gen_batch.meta_info.pop("generation_config")
gen_batch.meta_info.pop("response_callback_fn")
lm_input = lm_input.repeat(repeat_times=generation_config["num_return_sequences"])
@@ -286,11 +286,11 @@ def generate(self, env_output: Dict):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def run_rollout_loop(self, data: DataProto):
"""
- 1. 每次调用run_rollout_loop,
- 会持续的play episode, 直到收到采集完成的command
- 需要重置seed, 确保每个group的seed一致
- episode_id 置0
- seed更新逻辑:
+ 1. Each call to run_rollout_loop
+ keeps playing episodes until the collection-finished command is received.
+ The seed must be reset so that seeds stay consistent within each group,
+ and episode_id is reset to 0.
+ Seed update logic:
group_seed = seed + group_seed
episode_seed = group_seed + episode_id
@@ -379,9 +379,9 @@ def get_env_input(self, lm_output: DataProto) -> Dict:
def formulate_rollouts(self):
"""
- 1. 每个env的trajectory 是一个rollout
- 2. 每个rollout 是一个List[Dict]
- 3. 每个Dict 是一个step的信息
+ 1. Each env's trajectory is a rollout
+ 2. Each rollout is a List[Dict]
+ 3. Each Dict is information for one step
"""
llm_input_texts, messages_list = self._format_messages(
env_output=self.rollout_cache, prepare_for_update=True, use_raw_llm_response=False
@@ -436,7 +436,7 @@ def formulate_rollouts(self):
llm_inputs.batch["non_prompt_mask"] = non_prompt_mask
llm_inputs.batch["response_mask"] = non_prompt_mask
if self.pipeline_config.enable_response_mask:
- # 只使用llm的response mask,不包含环境的state
+ # use only the llm's response mask, excluding the environment state
llm_inputs.batch["response_mask"] = response_mask
first_true_indices = non_prompt_mask.int().argmax(dim=1)
no_true_mask = ~non_prompt_mask.any(dim=1)
@@ -601,7 +601,7 @@ def _format_messages(self, env_output: Dict, prepare_for_update: bool, use_raw_l
)
if "llm_raw_response" in content:
# yali: using the raw response will cause continuous crashes: https://aliyuque.antfin.com/mdl-team/traning/wmne4oyxg4dozwia
- # 改成actions合理吗?
+ # is it reasonable to change to actions?
messages.append(
{
"role": "assistant",
@@ -627,12 +627,12 @@ def _format_messages(self, env_output: Dict, prepare_for_update: bool, use_raw_l
else:
text += "" # force the LLM to answer
- # TODO: 应该没有必要,注意处理mask
+ # TODO: this should not be necessary; be careful with the mask handling
text = text.replace("<|im_end|>\n", "<|im_end|>")
return [text], [messages]
def _init_prefix_lookup(self):
- # TODO: 这里并不合理
+ # TODO: this is not reasonable here
prefix_lookup = {}
prefixes = {}
env_config_lookup = {}
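The seed-update rules quoted in run_rollout_loop's docstring above can be illustrated with a small sketch; the helper is hypothetical and the variable names follow the docstring.

def episode_seed(seed: int, group_seed: int, episode_id: int) -> int:
    # group_seed = seed + group_seed; episode_seed = group_seed + episode_id
    group = seed + group_seed
    return group + episode_id

# Every env in the same group (same seed/group_seed) replays the same seed sequence,
# so init states match within the group while differing across episodes.
print([episode_seed(1234, 7, e) for e in range(3)])  # [1241, 1242, 1243]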
diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py
index 94106271..9fca480d 100644
--- a/roll/pipeline/base_worker.py
+++ b/roll/pipeline/base_worker.py
@@ -163,7 +163,7 @@ def generate(self, data: DataProto):
@torch.no_grad()
def start_server(self, data: DataProto):
"""
- 解决dp generate的长尾问题,async+ load balance
+ Mitigate the long-tail problem of dp generate via async requests + load balancing
"""
global_step = data.meta_info.get("global_step", 0)
is_offload_states = data.meta_info.get("is_offload_states", True)
@@ -235,9 +235,9 @@ def compute_log_probs(self, data: DataProto):
def forward_func_log_probs(self, data: DataProto, output_tensor: torch.Tensor):
"""
- forward func 接口定义:
- data: DataProto, 由forward_step透传
- output_tensor: torch.Tensor, model.forward()的输出Tensor
+ forward func interface definition:
+ data: DataProto, passed through from forward_step
+ output_tensor: torch.Tensor, output tensor from model.forward()
"""
log_probs = self.strategy.op_compute_log_probs(
logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"]
@@ -247,9 +247,9 @@ def forward_func_log_probs(self, data: DataProto, output_tensor: torch.Tensor):
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
"""
- loss func接口定义:
- data: DataProto, 由train_step透传
- output_tensor: torch.Tensor, model.forward()的输出Tensor
+ loss func interface definition:
+ data: DataProto, passed through from train_step
+ output_tensor: torch.Tensor, output tensor from model.forward()
"""
response_mask = data.batch["response_mask"][:, 1:].long()
@@ -324,7 +324,7 @@ def do_checkpoint(self, global_step):
with Timer("do_checkpoint") as total_timer:
ckpt_id = f"checkpoint-{global_step}"
- # actor train是直接存在save dir目录下的,其他role是存在save_dir/cluster_name下的
+ # actor_train checkpoints are saved directly under save_dir; other roles are saved under save_dir/cluster_name
save_dir = os.path.join(self.pipeline_config.output_dir, self.worker_name, ckpt_id)
self.logger.info(f"save checkpoint-{global_step} to {save_dir}")
@@ -341,10 +341,10 @@ def do_checkpoint(self, global_step):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def add_request(self, command, data: DataProto):
"""
- data req meta_info里需要包含:
+ the request's meta_info must contain:
request_id: str
response_callback_fn: callable
- generation_config, 按request设置
+ generation_config, set per request
"""
if command == GenerateRequestType.ALIVE_CHECK:
if self.thread_server is not None:
@@ -474,9 +474,9 @@ def train_step(self, data: DataProto):
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
"""
- loss func接口定义:
- data: DataProto, 由train_step透传
- output_tensor: torch.Tensor, model.forward()的输出Tensor
+ loss func interface definition:
+ data: DataProto, passed through from train_step
+ output_tensor: torch.Tensor, output tensor from model.forward()
"""
response_mask = data.batch["response_mask"][:, 1:]
old_values = data.batch["values"]
@@ -534,7 +534,7 @@ def do_checkpoint(self, global_step):
class RewardWorker(Worker):
"""
- Reward Model 使用 AutoModelForSequenceClassification 协议
+ Reward Model uses AutoModelForSequenceClassification protocol
"""
def __init__(self, worker_config: WorkerConfig):
diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py
index 71bd3834..ab328f03 100644
--- a/roll/pipeline/rlvr/actor_worker.py
+++ b/roll/pipeline/rlvr/actor_worker.py
@@ -10,9 +10,9 @@ class ActorWorker(BaseActorWorker):
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
"""
- loss func接口定义:
- data: DataProto, 由train_step透传
- output_tensor: torch.Tensor, model.forward()的输出Tensor
+ loss func interface definition:
+ data: DataProto, passed through from train_step
+ output_tensor: torch.Tensor, output tensor from model.forward()
"""
response_mask = data.batch["response_mask"][:, 1:].long()
@@ -122,12 +122,12 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor):
"""
- 可以基于难度和长度的样本权重
+ Compute per-sample weights, optionally based on difficulty and length
"""
batch_size = response_mask.shape[0]
sample_weights = torch.ones(batch_size, device=response_mask.device)
- # 1. 基于难度的权重 - 例如:难度越高,权重越大
+ # 1. Difficulty-based weights - e.g. the higher the difficulty, the larger the weight
if self.pipeline_config.difficulty_loss_weight and "difficulty" in data.non_tensor_batch:
try:
difficulty = data.non_tensor_batch["difficulty"]
@@ -139,12 +139,12 @@ def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor):
difficulty_weights = 0.5 + 1.5 * norm_difficulty
sample_weights = sample_weights * difficulty_weights
except Exception as e:
- self.logger.warning(f"跳过difficulty权重计算:{str(e)}")
+ self.logger.warning(f"Skip difficulty weight calculation: {str(e)}")
- # 2. 基于长度的权重 - 例如:长度越长,权重越小
+ # 2. Length-based weights - e.g. the longer the response, the smaller the weight
response_lengths = response_mask.sum(dim=1).float()
if self.pipeline_config.length_loss_weight:
- # 同样归一化长度到[0.5, 2.0]范围
+ # Similarly normalize length to [0.5, 2.0] range
norm_lengths = (response_lengths - response_lengths.min()) / (
response_lengths.max() - response_lengths.min() + 1e-8
)
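A sketch of the length-based weighting that compute_sample_weights describes (lengths normalized, then mapped into a [0.5, 2.0] band so longer responses get smaller weights); the exact mapping below is an assumption, since the hunk cuts off before the weight formula.

import torch

def length_weights(response_mask: torch.Tensor) -> torch.Tensor:
    # Normalize response lengths to [0, 1], then map into [0.5, 2.0] so that the
    # longest response gets weight 0.5 and the shortest gets 2.0 (assumed mapping).
    lengths = response_mask.sum(dim=1).float()
    norm = (lengths - lengths.min()) / (lengths.max() - lengths.min() + 1e-8)
    return 2.0 - 1.5 * norm

print(length_weights(torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]])))  # tensor([0.5000, 2.0000])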
diff --git a/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py b/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py
index 2aadc5cb..c90b9448 100644
--- a/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py
+++ b/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py
@@ -79,7 +79,7 @@ def __init__(self, sandbox_url: str):
self.logger = get_logger()
def check_format(self, prompt_id: str, text: str):
- # 检查格式是否满足要求:回答中必须包含""和代码片段
+ # Check that the format meets the requirements: the response must contain "</think>" and a code block
has_think_tag = "</think>" in text
has_code_block = "```" in text
@@ -92,7 +92,7 @@ def check_format(self, prompt_id: str, text: str):
return format
def extract_code_blocks(self, prompt, text: str, case_type: str = "input"):
- """提取代码块"""
+ """Extract code blocks"""
if "<|begin_of_solution|>" in text:
text = text.split("<|begin_of_solution|>")[-1].strip()
if "" in text:
@@ -122,7 +122,7 @@ def extract_code_blocks(self, prompt, text: str, case_type: str = "input"):
return codes, langs, ""
def format_sandbox_test(self, test_code, code_language, case_type, test_cases) -> Optional[List[Dict]]:
- """格式化sandbox测试用例"""
+ """Format sandbox test cases"""
test_cases_final = []
if code_language is None or code_language == "":
# TDO detect programming language
@@ -249,7 +249,7 @@ def sandbox_test(self, prompt_id: str, test_cases: List[Dict]) -> Tuple[List[Dic
return result
def sanbox_result_judge(self, test_cases: List[Dict], sandbox_results: List[Dict]) -> int:
- """判断测试用例通过数量"""
+ """Judge the number of test cases passed"""
pass_test_number = 0
error_types = []
for i, responses in enumerate(sandbox_results):
@@ -304,7 +304,7 @@ def single_code_test(
prompt: str = "",
flag: int = 0,
):
- """单条代码测试"""
+ """Single code test"""
info = {
"global_step": global_step,
"prompt_id": prompt_id,
@@ -320,11 +320,11 @@ def single_code_test(
"error_info": [],
}
- # 判断格式是否满足要求
+ # Check if format meets requirements
format_validation = self.check_format(prompt_id, code)
info["format_validation"] = format_validation
start_time = time.time()
- # 抽取代码片段
+ # Extract code blocks
codes, code_langs, error_info = self.extract_code_blocks(prompt, code, case_type)
if error_info != "" or len(codes) == 0:
info["error_info"] = ["extract_code_blocks error"]
@@ -337,13 +337,13 @@ def single_code_test(
# TDO detect programming language
code_language = "python"
- # 格式化sandbox测试用例
+ # Format sandbox test cases
test_cases, error_info = self.format_sandbox_test(test_code, code_language, case_type, test_cases)
if error_info != "" or test_cases == None:
info["error_info"] = ["format_sandbox_test error"]
return info
- # 调用sandbox测试
+ # Call sandbox testing
succeed_test_cases, responses = self.sandbox_test(prompt_id, test_cases)
if not responses or len(succeed_test_cases) == 0:
info["error_info"] = ["sandbox error"]
@@ -351,7 +351,7 @@ def single_code_test(
info["validation"] = 0
return info
- # 判断sandbox测试结果
+ # Judge sandbox test results
pass_test_number, error_types = self.sanbox_result_judge(succeed_test_cases, responses)
time_duration = time.time() - start_time
@@ -489,8 +489,8 @@ def cal_local_test(prompt_id, response, test_cases, func_name=None, num_process_
class CodeSandboxRewardWorker(Worker):
"""
- (x)Reward Model 使用 AutoModelForSequenceClassification 协议
- 面向code的sandbox 单测的 reward model
+ (x) The Reward Model follows the AutoModelForSequenceClassification protocol
+ A reward model that runs code unit tests in a sandbox
"""
def __init__(self, worker_config: RewardConfig):
diff --git a/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py
index 254849c6..e5e2f92b 100644
--- a/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py
+++ b/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py
@@ -1,4 +1,4 @@
-# 导入必要的库和模块
+# Import necessary libraries and modules
from functools import partial
from typing import Optional, Union, Iterator
import json
@@ -21,7 +21,7 @@
from roll.utils.logging import get_logger
-logger = get_logger() # 获取日志记录器实例
+logger = get_logger() # Get logger instance
def get_response_length_reward(min_len, max_len):
@@ -70,19 +70,19 @@ def repetition_penalty_reward(response, **kwargs) -> float:
def extract_after_last_think(input_string, end_think=""):
"""
- 提取输入字符串中最后一个"end_think"标签之后的内容,
- 并移除结果字符串开头的所有换行符。
+ Extract content after the last "end_think" tag in the input string,
+ and remove all newlines at the beginning of the result string.
Args:
- input_string: 原始字符串。
+ input_string: Original string.
Returns:
- 提取并处理后的字符串。如果未找到"end_think"标签,则返回空字符串。
+ Extracted and processed string. Returns empty string if "end_think" tag not found.
"""
last_index = input_string.rfind(end_think)
if last_index == -1:
- return input_string # 或者根据需要返回 None 或原始字符串
+ return input_string # Or return None or original string as needed
start_pos = last_index + len(end_think)
extracted_part = input_string[start_pos:]
@@ -97,9 +97,9 @@ def crossthinkqa_reward_fn(response, ground_truth, reward_type):
correct_flag = False
# 1. format
- # 找到所有的 \\boxed{} 匹配项
+ # Find all \\boxed{} matches
box_matches = re.findall(r"\\boxed\{([^}]+)\}", response)
- # 如果没有找到 \\boxed{} 则返回 None
+ # If no \\boxed{} found, return None
if not box_matches:
lower_response = response.lower()
last_answer_index = lower_response.rfind("answer is")
@@ -107,7 +107,7 @@ def crossthinkqa_reward_fn(response, ground_truth, reward_type):
extracted_answer = response
else:
extracted_answer = response[last_answer_index + 9 :]
- # 获取最后一个 \\boxed{} 的内容
+ # Get content of the last \\boxed{}
else:
format_flag = True
extracted_answer = box_matches[-1]
@@ -145,8 +145,8 @@ def crossthinkqa_reward_fn(response, ground_truth, reward_type):
class CrossThinkQARuleRewardWorker(Worker):
"""
- 一个示例 Reward Worker,用于执行 ifeval 验证并把每个 func 的结果放到 output.tensors 中。
- 在此示例里,ground_truths的str
+ An example Reward Worker that runs ifeval validation and puts each func's result into output.tensors.
+ In this example, ground_truths are str values.
"""
def __init__(self, worker_config: WorkerConfig):
@@ -179,8 +179,8 @@ def compute_rewards(self, data: DataProto):
scores = []
repetition_penalty_rewards = []
response_length_rewards = []
- format_values = [] # 格式正确的value(严格要求有\boxed{})
- correct_values = [] # 答案正确的value(用更宽松的规则提取)
+ format_values = [] # Format-correct values (strictly require \boxed{})
+ correct_values = [] # Answer-correct values (use more lenient extraction rules)
for i, (resp_tokens, ground_truth, tag, prompt) in enumerate(
zip(data.batch["responses"], ground_truths, tags, prompts)
@@ -200,13 +200,13 @@ def compute_rewards(self, data: DataProto):
format_value = 1 if format_flag else 0
correct_value = 1 if correct_flag else 0
- # score应该为0或者1,标志模型回复的对错
+ # score should be 0 or 1, indicating model response correctness
if crossthinkqa_reward > 0:
score = 1.0
else:
score = 0.0
- # 存到 crossthinkqa_rewards
+ # Store to crossthinkqa_rewards
crossthinkqa_rewards.append(crossthinkqa_reward)
scores.append(score)
repetition_penalty_rewards.append(repetition_penalty_reward)
@@ -248,8 +248,8 @@ def compute_rewards(self, data: DataProto):
format_values = torch.tensor(format_values, dtype=torch.float16)
correct_values = torch.tensor(correct_values, dtype=torch.float16)
- # 5) 将这些张量打包进同一个字典
- # TODO: 不同的reward worker的output是否需要统一output,或者有没有自适应的办法,避免在新增监控量时每个worker都需要修改
+ # 5) Pack these tensors into the same dictionary
+ # TODO: should different reward workers share a unified output schema, or is there an adaptive way to avoid touching every worker when a new monitored metric is added?
output_tensors = {
"token_level_rewards": token_level_rewards,
"response_level_rewards": response_level_rewards,
@@ -260,6 +260,6 @@ def compute_rewards(self, data: DataProto):
# "correct_values": correct_values
}
- # 6) 用 DataProto.from_dict(...) 构造返回值
+ # 6) Use DataProto.from_dict(...) to construct return value
output = DataProto.from_dict(tensors=output_tensors)
return output
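The \boxed{} extraction with the "answer is" fallback used by crossthinkqa_reward_fn can be summarized in a small sketch; the helper is hypothetical, mirroring the regex and the 9-character offset (len("answer is")) seen in the hunk.

import re

def extract_answer(response: str) -> tuple:
    # Prefer the last \boxed{...}; otherwise fall back to the text after the final "answer is".
    box_matches = re.findall(r"\\boxed\{([^}]+)\}", response)
    if box_matches:
        return True, box_matches[-1]            # format_flag True: strict \boxed{} found
    lower = response.lower()
    idx = lower.rfind("answer is")
    if idx == -1:
        return False, response
    return False, response[idx + len("answer is"):]

print(extract_answer(r"... so we get \boxed{B}"))   # (True, 'B')
print(extract_answer("I think the answer is C"))     # (False, ' C')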
diff --git a/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py
index e0268e46..bbbc6db2 100644
--- a/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py
+++ b/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py
@@ -1,4 +1,4 @@
-# 导入必要的库和模块
+# Import necessary libraries and modules
from functools import partial
from typing import Optional, Union, Iterator
import json
@@ -21,24 +21,24 @@
from roll.utils.logging import get_logger
-logger = get_logger() # 获取日志记录器实例
+logger = get_logger() # Get logger instance
def extract_after_last_think(input_string, end_think=""):
"""
- 提取输入字符串中最后一个"end_think"标签之后的内容,
- 并移除结果字符串开头的所有换行符。
+ Extract content after the last "end_think" tag in the input string,
+ and remove all newlines at the beginning of the result string.
Args:
- input_string: 原始字符串。
+ input_string: Original string.
Returns:
- 提取并处理后的字符串。如果未找到"end_think"标签,则返回空字符串。
+ Extracted and processed string. Returns empty string if "end_think" tag not found.
"""
last_index = input_string.rfind(end_think)
if last_index == -1:
- return input_string # 或者根据需要返回 None 或原始字符串
+ return input_string # Or return None or original string as needed
start_pos = last_index + len(end_think)
extracted_part = input_string[start_pos:]
@@ -52,9 +52,9 @@ def single_choice_reward(response, ground_truth):
correct_flag = False
# 1. format
- # 找到所有的 \\boxed{} 匹配项
+ # Find all \\boxed{} matches
box_matches = re.findall(r"\\boxed\{([^}]+)\}", response)
- # 如果没有找到 \\boxed{} 则返回 None
+ # If no \\boxed{} found, return None
if not box_matches:
lower_response = response.lower()
last_answer_index = lower_response.rfind("answer is")
@@ -62,7 +62,7 @@ def single_choice_reward(response, ground_truth):
extracted_answer = response
else:
extracted_answer = response[last_answer_index + 9 :]
- # 获取最后一个 \\boxed{} 的内容
+ # Get content of the last \\boxed{}
else:
format_flag = True
extracted_answer = box_matches[-1]
@@ -100,8 +100,8 @@ def single_choice_reward(response, ground_truth):
class GeneralValRuleRewardWorker(Worker):
"""
- 一个示例 Reward Worker,用于执行 ifeval 验证并把每个 func 的结果放到 output.tensors 中。
- 在此示例里,ground_truths的str
+ An example Reward Worker that runs ifeval validation and puts each func's result into output.tensors.
+ In this example, ground_truths are str values.
"""
def __init__(self, worker_config: WorkerConfig):
@@ -125,8 +125,8 @@ def compute_rewards(self, data: DataProto):
tags = data.non_tensor_batch["tag"]
scores = []
- format_values = [] # 格式正确的value(严格要求有\boxed{})
- correct_values = [] # 答案正确的value(用更宽松的规则提取)
+ format_values = [] # Format-correct values (strictly require \boxed{})
+ correct_values = [] # Answer-correct values (use more lenient extraction rules)
for i, (resp_tokens, ground_truth, tag, prompt) in enumerate(
zip(data.batch["responses"], ground_truths, tags, prompts)
@@ -141,13 +141,13 @@ def compute_rewards(self, data: DataProto):
extracted_answer, reward, format_flag, correct_flag = single_choice_reward(answer_text, ground_truth)
format_value = 1 if format_flag else 0
correct_value = 1 if correct_flag else 0
- # score应该为0或者1,标志模型回复的对错
+ # score should be 0 or 1, indicating model response correctness
if reward > 0:
score = 1.0
else:
score = 0.0
- # 存到 scores
+ # Store to scores
scores.append(score)
format_values.append(format_value)
correct_values.append(correct_value)
@@ -174,7 +174,7 @@ def compute_rewards(self, data: DataProto):
token_level_rewards = torch.zeros_like(data.batch["responses"], dtype=torch.float16)
response_level_rewards = torch.zeros_like(scores, dtype=torch.float16)
- # 5) 将这些张量打包进同一个字典
+ # 5) Pack these tensors into the same dictionary
output_tensors = {
"scores": scores,
# "format_values": format_values,
@@ -183,6 +183,6 @@ def compute_rewards(self, data: DataProto):
"response_level_rewards": response_level_rewards,
}
- # 6) 用 DataProto.from_dict(...) 构造返回值
+ # 6) Use DataProto.from_dict(...) to construct return value
output = DataProto.from_dict(tensors=output_tensors)
return output
diff --git a/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py
index f94ad9ea..edd43050 100644
--- a/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py
+++ b/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py
@@ -1,4 +1,4 @@
-# 导入必要的库和模块
+# Import necessary libraries and modules
from functools import partial
from typing import Optional, Union, Iterator
import json
@@ -15,7 +15,7 @@
import itertools, collections
from collections import defaultdict
-# 从已有的 WorkerConfig、Worker、Dispatch 等模块导入
+# Import from existing WorkerConfig, Worker, Dispatch and other modules
from roll.configs.worker_config import WorkerConfig
from roll.distributed.executor.worker import Worker
from roll.distributed.scheduler.decorator import Dispatch, register
@@ -31,40 +31,40 @@
from nltk.corpus import wordnet as wn
from nltk.wsd import lesk
-# 假设 tokenizer 依然来自 default_tokenizer_provider
+# Assume tokenizer still comes from default_tokenizer_provider
from roll.models.model_providers import default_reward_model_provider, default_tokenizer_provider
-# 引入 ifeval 验证函数的字典映射
-# IF_FUNCTIONS_MAP 是题主在上面给出的完整实现中包含的函数映射
+ # Import the dict that maps ifeval validation function names to implementations
+ # IF_FUNCTIONS_MAP is the function mapping from the full implementation referenced above
from typing import Union, Dict, List
from roll.utils.logging import get_logger
-logger = get_logger() # 获取日志记录器实例
+logger = get_logger() # Get logger instance
def first_boxed(text: str) -> str | None:
"""
- 提取第一个 \boxed{...} 的内容,支持 boxed 内部再嵌套 {}。
+ Extract content of the first \boxed{...}, supporting nested {} inside boxed.
"""
marker = r"\boxed{"
start = text.find(marker)
if start == -1:
- return "" # 没找到 \boxed{
+ return "" # No \boxed{ found
- i = start + len(marker) # 跳过 '\boxed{'
- depth = 1 # 已进入 1 层 {
+ i = start + len(marker) # Skip '\boxed{'
+ depth = 1 # Already entered 1 level of {
buf = []
- while i < len(text) and depth: # 扫描直到配平
+ while i < len(text) and depth: # Scan until balanced
ch = text[i]
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
- if depth == 0: # 恢复到 0 说明 boxed 完成
+ if depth == 0: # Return to 0 means boxed is complete
break
- if depth: # 只在括号未配平时记录字符
+ if depth: # Only record characters when brackets are not balanced
buf.append(ch)
i += 1
@@ -73,8 +73,8 @@ def first_boxed(text: str) -> str | None:
class timeout:
"""
- 与 MathRewardWorker 示例中类似的超时上下文,用于演示,
- 如果不需要超时,可直接省略。
+ A timeout context manager similar to the one in the MathRewardWorker example, shown for demonstration;
+ it can be omitted if no timeout is needed.
"""
def __init__(self, seconds=1, error_message="Timeout"):
@@ -92,97 +92,97 @@ def __exit__(self, type, value, traceback):
signal.alarm(0)
-# 包含关键字:在你的回答中应包含关键字 {keyword1}、{keyword2}
+# Contains keywords: Your answer should include keywords {keyword1}, {keyword2}
def verify_keywords(text, keyword_list):
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 将响应文本转换为小写,以便进行不区分大小写的匹配
+ # Convert response text to lowercase for case-insensitive matching
response_lower = text.lower()
- # 检查响应中是否包含所有关键字(每个关键字也转换为小写进行匹配)
+ # Check if response contains all keywords (each keyword also converted to lowercase for matching)
return all(keyword.lower() in response_lower for keyword in keyword_list)
-# 关键字出现频率:在你的回答中,单词 {word} 应该出现 {N} 次
+# Keyword frequency: In your answer, word {word} should appear {N} times
def verify_keyword_frequency(text, word, N):
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 将文本转换为小写,使搜索不区分大小写
+ # Convert text to lowercase for case-insensitive search
text = text.lower()
word = word.lower()
- # 使用正则表达式匹配单词边界,将文本切分为单词列表
+ # Use regex to match word boundaries and split text into word list
words = re.findall(r"\b\w+\b", text)
- # 统计实际出现次数(精确匹配关键字)
+ # Count actual occurrences (exact match keywords)
actual_count = sum(1 for word in words if word == word)
- # 检查实际出现次数是否等于期望的 N 次
+ # Check if actual occurrence count equals expected N times
constraint_met = actual_count == N
return constraint_met
-# 禁止出现特定单词:回答中不应包含关键字 {forbidden words}
+# Forbidden words: Answer should not contain keywords {forbidden words}
def validate_forbidden_words(text, forbidden_words):
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 将文本转换为小写,进行不区分大小写的匹配
+ # Convert text to lowercase for case-insensitive matching
text_lower = text.lower()
- # 检查每个禁止单词是否出现在文本中
+ # Check if each forbidden word appears in the text
found_words = [word for word in forbidden_words if word.lower() in text_lower]
- # 如果没有找到禁止单词,返回 True;否则返回 False
+ # If no forbidden words found, return True; otherwise return False
return len(found_words) == 0
-# 字母出现频率:在你的回答中,字母 {letter} 应该恰好出现 {N} 次
+# Letter frequency: In your answer, letter {letter} should appear exactly {N} times
def verify_letter_frequency(text: str, letter: str, N: int) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
if len(letter) != 1:
- raise ValueError("字母参数必须为单个字符")
+ raise ValueError("Letter parameter must be a single character")
actual_count = text.count(letter)
return actual_count == N
-# 回答语言约束:你的整个回答应当使用 {language},不允许包含其他语言内容
+# Response language constraint: Your entire answer should use {language}, no other language content allowed
def validate_response_language(text, language):
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
from langdetect import detect
- # 检测文本的语言
+ # Detect text language
detected_language = detect(text)
- # 检查检测到的语言是否与预期语言相符
+ # Check if detected language matches expected language
return detected_language == language
-# 段落数量:回答中应包含 {N} 个段落,段落之间使用 markdown 分隔符 "* * *" 隔开
+# Paragraph count: Answer should contain {N} paragraphs, separated by markdown separator "* * *"
def verify_paragraph_count(text: str, N: int) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
def clean_text(text: str) -> str:
- """移除多余空白字符,并规范换行符"""
+ """Remove excess whitespace and normalize line breaks"""
return "\n".join(line.strip() for line in text.splitlines()).strip()
- # 清理输入文本
+ # Clean input text
text = clean_text(text)
- # 依据 markdown 分隔符分割文本,每个分隔符会创建 n+1 个段落
+ # Split text by markdown separator, each separator creates n+1 paragraphs
paragraphs = text.split("* * *")
actual_count = len(paragraphs)
- # 验证每个分割结果中是否包含非空内容
+ # Verify each split result contains non-empty content
valid_paragraphs = [p.strip() for p in paragraphs if p.strip()]
if len(valid_paragraphs) != actual_count:
return False
@@ -190,16 +190,16 @@ def clean_text(text: str) -> str:
return actual_count == N
-# 单词数量约束:回答中的单词数应至少/大约/最多达到 {N} 个
+# Word count constraint: Answer word count should be at least/around/at most {N}
def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 清除多余空白字符并拆分文本为单词列表
+ # Remove excess whitespace and split text into word list
words = text.strip().split()
actual_count = len(words)
- # 定义 "around" 约束的容错范围(目标单词数的 ±10%,至少 1 个单词)
+ # Define tolerance range for "around" constraint (±10% of target word count, at least 1 word)
tolerance = max(round(N * 0.1), 1)
if quantifier == "at least":
@@ -212,18 +212,18 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
return False
-# 句子数量约束:回答中应包含至少/大约/最多 {N} 个句子
+# Sentence count constraint: Answer should contain at least/around/at most {N} sentences
def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 使用正则表达式根据句号或问号后的空格拆分文本为句子列表
+ # Use regex to split text into sentence list based on periods or question marks followed by spaces
sentences = re.split(r"(?= N
elif quantifier == "around":
@@ -234,61 +234,61 @@ def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
return False
-# 段落数量及指定段落首词约束:回答中应包含 {N} 个段落,段落之间仅以两个换行符分隔,第 {i} 个段落必须以 {first word} 开头
+# Paragraph count and specific paragraph first word constraint: Answer should contain {N} paragraphs separated only by two newlines, paragraph {i} must start with {first word}
def validate_paragraphs(text, N, first_word, i):
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 根据两个换行符分割文本为段落
+ # Split text into paragraphs by two newlines
paragraphs = text.split("\n\n")
- # 检查段落总数是否符合要求
+ # Check if total paragraph count meets requirements
if len(paragraphs) != N:
return False
- # 检查第 i 个段落的开头是否为指定单词
+ # Check if paragraph i starts with specified word
if paragraphs[i - 1].strip().startswith(first_word):
return True
return False
-# 附言验证:请在回答末尾明确添加以 {postscript marker} 开头的附言
+# Postscript validation: Please clearly add a postscript starting with {postscript marker} at the end of your answer
def verify_postscript(text, postscript_marker):
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 检查文本中是否包含附言标记
+ # Check if text contains postscript marker
if postscript_marker in text:
- # 获取标记的索引位置
+ # Get marker index position
marker_index = text.find(postscript_marker)
- # 检查标记附近是否还有其它内容
+ # Check if there's other content near the marker
remaining_text = text[marker_index:].strip()
- # 验证附言不只是标记本身而已
+ # Verify postscript is not just the marker itself
return len(remaining_text) > len(postscript_marker)
return False
-# 占位符验证:回答中应至少包含 {N} 个用方括号表示的占位符,例如 [address]
+# Placeholder validation: Answer should contain at least {N} placeholders in square brackets, e.g. [address]
def validate_placeholders(text: str, N: int) -> tuple[bool, List[str]]:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 使用正则表达式查找所有位于方括号内的内容
+ # Use regex to find all content within square brackets
pattern = r"\[(.*?)\]"
placeholders = re.findall(pattern, text)
- # 检查是否至少找到了 N 个占位符
+ # Check if at least N placeholders were found
has_enough = len(placeholders) >= N
return has_enough
-# 项目符号验证:回答必须包含恰好 {N} 个项目符号点。请使用 markdown 格式的项目点,例如:* 这是一个点。
+# Bullet point validation: Answer must contain exactly {N} bullet points. Use markdown format bullet points, e.g.: * This is a point.
def verify_bullet_points(text: str, N: int) -> tuple[bool, str]:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
- # 按行拆分文本,并统计以 * 或 - 开头的行
+ # Split text by lines and count lines starting with * or -
lines = text.split("\n")
bullet_points = [line.strip() for line in lines if line.strip().startswith(("*", "-"))]
actual_count = len(bullet_points)
@@ -299,7 +299,7 @@ def verify_bullet_points(text: str, N: int) -> tuple[bool, str]:
return False
-# 标题验证:回答中必须包含一个标题,用双尖括号包裹,例如 <>
+# Title validation: Answer must contain a title wrapped in double angle brackets, e.g. <>
def validate_title(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -313,7 +313,7 @@ def validate_title(text: str) -> bool:
return False
-# 选择题验证:回答内容必须为以下选项之一:{options}
+# Multiple choice validation: Answer content must be one of the following options: {options}
def validate_choice(text: str, options: list) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -324,7 +324,7 @@ def validate_choice(text: str, options: list) -> bool:
return False
-# 高亮区域数量验证:回答中必须至少高亮 {N} 个区域,使用 markdown 格式,比如 *highlighted section*
+# Highlighted section count validation: Answer must highlight at least {N} sections using markdown format, e.g. *highlighted section*
def validate_highlighted_sections(text: str, N: int) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -338,13 +338,13 @@ def validate_highlighted_sections(text: str, N: int) -> bool:
return False
-# 多区块验证:回答中必须包含 {N} 个区块,每个区块的开始都应以 {section splitter} 开头
+# Multi-section validation: Answer must contain {N} sections, each section should start with {section splitter}
def validate_sections(text: str, N: int, section_splitter: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
"""
sections = text.split(section_splitter)
- # 第一个区块可能不以分割符开头,因此需要做调整
+ # First section may not start with splitter, so adjustment needed
if sections[0] == "":
sections.pop(0)
if len(sections) == N:
@@ -353,7 +353,7 @@ def validate_sections(text: str, N: int, section_splitter: str) -> bool:
return False
-# JSON 格式验证:整个输出必须使用 JSON 格式包裹
+# JSON format validation: Entire output must be wrapped in JSON format
def validate_json_format(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -365,7 +365,7 @@ def validate_json_format(text: str) -> bool:
return True
-# 重复提示验证:首先重复用户的请求内容不做更改,然后再给出你的回答(重复内容不应包含其他额外信息)
+# Repeat prompt validation: First repeat user's request content unchanged, then give your answer (repeated content should not contain other additional information)
def validate_repeat_prompt(text: str, original_prompt: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -376,7 +376,7 @@ def validate_repeat_prompt(text: str, original_prompt: str) -> bool:
return False
-# 两种回答验证:提供两种不同的回答,两个回答之间仅用六个星号 "******" 分隔开
+# Two responses validation: Provide two different answers separated only by six asterisks "******"
def validate_two_responses(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -390,7 +390,7 @@ def validate_two_responses(text: str) -> bool:
return False
-# 全部大写:整个回答必须全部使用英文大写字母
+# All uppercase: Entire answer must use English uppercase letters
def validate_uppercase(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -401,7 +401,7 @@ def validate_uppercase(text: str) -> bool:
return False
-# 全部小写:整个回答必须全部使用英文小写字母,不允许有大写字母
+# All lowercase: Entire answer must use English lowercase letters, no uppercase allowed
def validate_lowercase(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -412,7 +412,7 @@ def validate_lowercase(text: str) -> bool:
return False
-# 全大写单词出现频率验证:在回答中,全大写单词的出现次数应满足至少/大约/最多 {N} 次
+# All-caps word frequency validation: In the answer, all-caps words should appear at least/around/at most {N} times
def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -428,7 +428,7 @@ def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool
return False
-# 结束语验证:回答最后必须以确切的短语 {end phrase} 结束,且该短语后面不允许有其他内容
+# End phrase validation: Answer must end with exact phrase {end phrase}, with no other content after it
def validate_end(text: str, end_phrase: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -439,7 +439,7 @@ def validate_end(text: str, end_phrase: str) -> bool:
return False
-# 引号包装验证:整个回答必须用双引号包裹起来
+# Quotation wrapping validation: Entire answer must be wrapped in double quotes
def validate_quotation(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -450,7 +450,7 @@ def validate_quotation(text: str) -> bool:
return False
-# 禁用逗号:整个回答中不允许出现任何逗号
+# Comma ban: No commas allowed in entire answer
def validate_no_commas(text: str) -> bool:
"""
Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py
@@ -463,23 +463,23 @@ def validate_no_commas(text: str) -> bool:
def call_ifeval_function(func, text: str, constraint_dict: dict):
"""
- 1) 获取func的函数签名
- 2) 只保留与签名匹配且非None的参数
- 3) 调用func(text, **filtered_args)
+ 1) Get function signature of func
+ 2) Only keep parameters that match signature and are not None
+ 3) Call func(text, **filtered_args)
"""
- # 1) 获取函数签名
+ # 1) Get function signature
sig = inspect.signature(func)
- valid_params = set(sig.parameters.keys()) # 该函数形参名的集合
+ valid_params = set(sig.parameters.keys()) # Set of function parameter names
- # 2) 过滤掉 constraint_dict 中的无关字段和 None 值
- # (如果一个函数的参数刚好是 None 值也是合法,就保留;否则你可以额外判断)
+ # 2) Filter out fields in constraint_dict that are irrelevant or None
+ # (if a parameter's value happens to be None but is still valid, keep it; otherwise add an extra check)
filtered_args = {}
for k, v in constraint_dict.items():
- if k in valid_params: # 形参里确实有这个字段
- # 如果你想彻底丢弃 None,也可以加上: if v is not None:
+ if k in valid_params: # the function signature actually has this parameter
+ # If you want to drop None values entirely, also add: if v is not None:
filtered_args[k] = v
- # 3) 调用函数
+ # 3) Call function
return func(text, **filtered_args)
@@ -487,35 +487,35 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
"""
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
"""
- # 如果 max_penalty 是正的,这里直接抛出错误,说明要用负值来做惩罚
+ # If max_penalty is positive, raise an error: the penalty must be a negative value
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")
- # 内部函数 zipngram,用于切分文本为 ngram
+ # Internal function zipngram for splitting text into ngrams
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
- # repetition_penalty_reward 函数用于计算在给定 response 中,n-gram 的重复程度
+ # repetition_penalty_reward computes the degree of n-gram repetition in the given response
def repetition_penalty_reward(response, **kwargs) -> float:
"""
ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
"""
- # 如果回复为空或不足 ngram 大小,则直接返回 0
+ # If the response is empty or shorter than the ngram size, return 0 directly
if response == "" or len(response.split()) < ngram_size:
return 0.0
- # 遍历所有 ngram,统计 unique ngram 和 total ngram 的数量
+ # Iterate over all ngrams, counting unique ngrams and the total number of ngrams
ngrams = set()
total = 0
for ng in zipngram(response, ngram_size):
ngrams.add(ng)
total += 1
- # scaling = 1 - (不重复的 ngram / 总的 ngram 数量)
- # 不重复的越少(重复越多)scaling 越大
+ # scaling = 1 - (unique ngrams / total ngrams)
+ # the fewer unique ngrams (i.e. the more repetition), the larger the scaling
scaling = 1 - len(ngrams) / total
- # reward 是 scaling 乘以 max_penalty
+ # reward is scaling multiplied by max_penalty
reward = scaling * max_penalty
return reward
@@ -524,30 +524,30 @@ def repetition_penalty_reward(response, **kwargs) -> float:
def extract_after_last_think(input_string, end_think=""):
"""
- 提取输入字符串中最后一个 '' 标签之后的内容,
- 并移除结果字符串开头的所有换行符。
+ Extract content after the last '</think>' tag in the input string,
+ and remove all newlines at the beginning of the result string.
Args:
- input_string: 原始字符串。
+ input_string: Original string.
Returns:
- 提取并处理后的字符串。如果未找到 '' 标签,则返回空字符串。
+ Extracted and processed string. Returns an empty string if the '</think>' tag is not found.
"""
- # 查找最后一个 end_think 的起始位置
+ # Find starting position of last end_think
last_index = input_string.rfind(end_think)
- # 如果没有找到 end_think
+ # If end_think not found
if last_index == -1:
# return ""
- return input_string # 或者根据需要返回 None 或原始字符串
+ return input_string # Or return None or original string as needed
- # 计算 end_think 结束后的位置
+ # Calculate position after end_think ends
start_pos = last_index + len(end_think)
- # 提取 end_think 之后的部分
+ # Extract part after end_think
extracted_part = input_string[start_pos:]
- # 移除开头的所有换行符 '\n'
+ # Remove all leading newlines '\n'
cleaned_part = extracted_part.lstrip("\n")
return cleaned_part
@@ -555,8 +555,8 @@ def extract_after_last_think(input_string, end_think=""):
class GeneralRuleRewardWorker(Worker):
"""
- 一个示例 Reward Worker,用于执行 ifeval 验证并把每个 func 的结果放到 output.tensors 中。
- 在此示例里,ground_truths的str
+ An example Reward Worker that runs ifeval validation and puts each func's result into output.tensors.
+ In this example, ground_truths are str values.
"""
def __init__(self, worker_config: WorkerConfig):
@@ -577,20 +577,20 @@ def initialize(self, pipeline_config):
@register(dispatch_mode=Dispatch.DP_MP_COMPUTE)
def compute_rewards(self, data: DataProto):
"""
- 仅调用 data.non_tensor_batch['ground_truth'] 中的 “func_name”,
- 并将其结果作为单一的 response-level 奖励返回。
+ Only call "func_name" in data.non_tensor_batch['ground_truth'],
+ and return its result as single response-level reward.
"""
- # 1) 解码回复文本
+ # 1) Decode response text
response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=False)
batch_size = len(response_text_list)
- # 2) 读取 ground_truth(其中是一串 JSON,包含 func_name 等参数)
+ # 2) Read ground_truth (which is a JSON string containing func_name and other parameters)
prompts = data.non_tensor_batch["prompt"]
ground_truths = data.non_tensor_batch["ground_truth"]
tags = data.non_tensor_batch["tag"]
- # 3) 准备一个列表存放验证结果
+ # 3) Prepare a list to store validation results
results = [0.0] * batch_size
repetition_penalty_rewards = []
response_length_rewards = []
@@ -598,35 +598,35 @@ def compute_rewards(self, data: DataProto):
for i, (resp_tokens, ground_truth, tag, prompt) in enumerate(
zip(data.batch["responses"], ground_truths, tags, prompts)
):
- # 解码当前条目
+ # Decode current entry
resp_text = self.tokenizer.decode(resp_tokens, skip_special_tokens=False)
resp_text1 = resp_text.replace("<|endoftext|>", "").replace("", "").replace("<|im_end|>", "")
resp_text = extract_after_last_think(resp_text1)
# logger.info(f"extract_after_last_think(resp_text): {resp_text}")
if tag == "ifeval":
- # 解析 ground_truth (JSON) 得到约束信息
+ # Parse ground_truth (JSON) to get constraint information
if isinstance(ground_truth, str):
constraint_dict = json.loads(ground_truth)
else:
- constraint_dict = ground_truth # 如果已经是 dict,就直接用
+ constraint_dict = ground_truth # If already dict, use directly
- # 从约束中取出 func_name
+ # Extract func_name from constraints
func_name = constraint_dict.get("func_name", None)
if not func_name or func_name not in IF_FUNCTIONS_MAP:
self.logger.warning("constraint missing func_name")
- # 如果无 func_name 或没找到对应函数
- # 那么这里我们将结果记为 0.0(也可做别的处理)
+ # If there is no func_name, or no matching function is found,
+ # record the result as 0.0 here (other handling is also possible)
results[i] = 0.0
continue
- # 移除 func_name,其它参数传给函数
+ # Remove func_name, pass other parameters to function
constraint_dict.pop("func_name")
func = IF_FUNCTIONS_MAP[func_name]
# print(f"Running function {func_name} with Response text: {resp_text}")
# print(f"Response text: {resp_text}")
- # 调用函数进行验证
+ # Call function for validation
try:
result = call_ifeval_function(func, resp_text, constraint_dict)
except Exception as e:
@@ -635,7 +635,7 @@ def compute_rewards(self, data: DataProto):
else:
self.logger.warning(f"Unknown tag: {tag}")
- # 将结果转为 float: bool -> (1.0/0.0), 数值 -> float(...), 其他结构 -> bool(...)
+ # Convert result to float: bool -> (1.0/0.0), numeric -> float(...), other structures -> bool(...)
if isinstance(result, bool):
val = 1.0 if result else 0.0
elif isinstance(result, (int, float)):
@@ -643,27 +643,27 @@ def compute_rewards(self, data: DataProto):
else:
val = 1.0 if result else 0.0
- # 存到 results
+ # Store to results
results[i] = val
repetition_penalty_rewards.append(self.repetition_penalty_reward_fn(resp_text1))
- # 4) 准备输出张量:
- # - token_level_rewards:形状与 responses 相同、全 0
- # - response_level_rewards:即 results
- # - scores:可与 response_level_rewards 相同(用于统计/日志)
+ # 4) Prepare output tensors:
+ # - token_level_rewards: same shape as responses, all 0
+ # - response_level_rewards: i.e. results
+ # - scores: can be same as response_level_rewards (for statistics/logging)
token_level_rewards = torch.zeros_like(data.batch["responses"], dtype=torch.float16)
scores = torch.tensor(results, dtype=torch.float16)
repetition_penalty_rewards = torch.tensor(repetition_penalty_rewards, dtype=torch.float16)
response_level_rewards = scores + repetition_penalty_rewards
- # 5) 将这些张量打包进同一个字典
+ # 5) Pack these tensors into the same dictionary
output_tensors = {
"token_level_rewards": token_level_rewards,
"response_level_rewards": response_level_rewards,
"scores": scores
}
- # 6) 用 DataProto.from_dict(...) 构造返回值
+ # 6) Use DataProto.from_dict(...) to construct return value
output = DataProto.from_dict(tensors=output_tensors)
return output
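The try block above delegates to call_ifeval_function; a hedged sketch of what such a dispatcher can look like (assumed behavior, not the repo's exact implementation) is to pass the response plus only the constraint kwargs the ifeval function accepts:

import inspect

def call_ifeval_function_sketch(func, response, constraint_dict):
    # keep only kwargs that the constraint function actually declares
    accepted = set(inspect.signature(func).parameters)
    kwargs = {k: v for k, v in constraint_dict.items() if k in accepted and v is not None}
    return func(response, **kwargs)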
diff --git a/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py b/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py
index e31ec774..fa8eaf99 100644
--- a/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py
+++ b/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py
@@ -34,7 +34,7 @@ def __init__(self, worker_config: WorkerConfig):
self.tokenizer = None
self.strategy: Optional[Union[InferenceStrategy, TrainStrategy]] = None
- # LLM judge相关配置
+ # LLM judge related configuration
self.judge_prompt = self.worker_config.judge_prompt if hasattr(self.worker_config, "judge_prompt") else None
self.judge_prompt = prompt_maps[self.judge_prompt]
self.judge_model_type = (
diff --git a/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py
index ca299f08..7db81236 100644
--- a/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py
+++ b/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py
@@ -48,8 +48,8 @@ def _hf_verify_math_sample(response, answer, output_queue):
def hf_verify_math_sample(answer_a, answer_b, timeout_sec=5.0):
"""
- 在多进程中调用 hf math verify,
- 以在超时时间内完不成时返回 False.
+ Run hf math verify in a subprocess,
+ returning False if it does not finish within the timeout.
"""
output_queue = multiprocessing.Queue()
@@ -58,7 +58,7 @@ def hf_verify_math_sample(answer_a, answer_b, timeout_sec=5.0):
p.join(timeout_sec)
if p.is_alive():
- # 超时 -> 杀掉子进程, 返回 False
+ # Timeout -> kill subprocess, return False
p.terminate()
p.join()
return False, "", ""
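The same timeout pattern in isolation, with an illustrative verifier standing in for the real _hf_verify_math_sample target: run the check in a subprocess, join with a timeout, and terminate the process if it is still alive.

import multiprocessing

def _toy_verify(a, b, q):
    q.put(a.strip() == b.strip())  # stand-in for the real math-verify call

def verify_with_timeout(a, b, timeout_sec=5.0):
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_toy_verify, args=(a, b, q))
    p.start()
    p.join(timeout_sec)
    if p.is_alive():  # did not finish within timeout_sec
        p.terminate()
        p.join()
        return False
    return q.get() if not q.empty() else False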
@@ -118,8 +118,8 @@ def format_reward_fn(text: str, pattern: Optional[str] = r"^.*?.*
class MathRuleRewardWorker(Worker):
"""
- (x)Reward Model 使用 AutoModelForSequenceClassification 协议
- 面向math的rule reward model
+ (x)Reward Model uses AutoModelForSequenceClassification protocol
+ Math-oriented rule reward model
"""
def __init__(self, worker_config: WorkerConfig):
diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py
index 17cffec8..0da46e81 100644
--- a/roll/pipeline/rlvr/rlvr_config.py
+++ b/roll/pipeline/rlvr/rlvr_config.py
@@ -282,7 +282,7 @@ def set_max_steps(self, max_steps: int):
self.critic.training_args.per_device_train_batch_size
* self.critic.training_args.gradient_accumulation_steps
)
- # 没有除dp_size,需要在分布式环境初始化后再除
+ # Not yet divided by dp_size; the division happens after the distributed environment is initialized
self.actor_train.training_args.max_steps = max_steps * (
self.rollout_batch_size
* self.actor_infer.generating_args.num_return_sequences
diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py
index 6011f014..eae67ff7 100644
--- a/roll/pipeline/rlvr/rlvr_pipeline.py
+++ b/roll/pipeline/rlvr/rlvr_pipeline.py
@@ -37,7 +37,7 @@
def preprocess_dataset(dataset, prompt_len, encode_function, num_proc):
- # 处理数据
+ # Process data
print(f"Begin : {dataset}")
dataset = dataset.map(
encode_function,
@@ -46,7 +46,7 @@ def preprocess_dataset(dataset, prompt_len, encode_function, num_proc):
desc="Encoding dataset",
load_from_cache_file=False,
)
- # 过滤cutoff
+ # Filter by the cutoff length
dataset = dataset.filter(
lambda data_i: 5 < len(data_i["input_ids"]) <= prompt_len,
num_proc=num_proc,
@@ -83,7 +83,7 @@ def update_dataset_domain(tag_2_domain: Dict[str, set[str]], row):
def query_filter_fn(data_list: List[DataProto], config: RLVRConfig) -> bool:
"""
- 各domain的过滤规则可以自定义
+ Filtering rules can be customized per domain
"""
response_level_rewards = [data.batch["response_level_rewards"] for data in data_list]
if len(response_level_rewards) == 1:
@@ -127,7 +127,7 @@ def __init__(self, pipeline_config: RLVRConfig):
val_dataset_paths = self.pipeline_config.validation.data_args.file_name
self.val_dataset = datasets.load_dataset("json", data_files=val_dataset_paths)["train"]
- # 加上format,然后转ids的func
+ # Func that applies the format and then converts to ids
template_name = (
self.pipeline_config.global_template
if self.pipeline_config.global_template
@@ -311,9 +311,9 @@ def __init__(self, pipeline_config: RLVRConfig):
@torch.no_grad()
def run(self):
- # 计算tokens per second 系统吞吐
+ # Compute system throughput in tokens per second
- # 创建一个专门管理监控指标的类
+ # Create a specialized class for managing monitoring metrics
metrics_mgr = MetricsManager()
tps_timer = _Timer(window_size=5)
@@ -330,7 +330,7 @@ def run(self):
metrics_mgr.clear_metrics()
with tps_timer, Timer(name="step_total", logger=None) as step_total_timer:
- # 先model update,resume时不需要保存infer cluster的状态
+ # Do the model update first; the infer cluster state does not need to be saved when resuming
if self.pipeline_config.adv_estimator == "gae":
self.critic.offload_states(blocking=True)
self.actor_train.offload_states(blocking=True)
@@ -349,7 +349,7 @@ def run(self):
batch: DataProto = DataProto()
batch.meta_info = {"global_step": global_step}
- # 要按domain group by生成对应的batch
+ # Generate corresponding batches grouped by domain
with actor_infer_timer, actor_infer_response_timer, Timer(
name="step_generate", logger=None
) as step_generate_timer:
@@ -409,18 +409,18 @@ def run(self):
metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {}))
metrics_mgr.add_metric("time/old_log_probs", cal_old_logpb_timer.last)
- # 要按domain group by处理reward
+ # Process rewards grouped by domain
batch.batch["prompt_id"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device)
batch_grouped: Dict[str, DataProto] = batch.group_by("domain")
batch_list = []
for domain, domain_batch in batch_grouped.items():
- # 1. 处理mask相关策略, 获取sample level mask
+ # 1. Apply mask-related strategies to get the sample-level mask
with Timer(name="get_sample_level_mask", logger=None) as get_sample_level_mask_timer:
domain_batch, mask_metrics = get_sample_level_mask(domain_batch, self.pipeline_config)
metrics_mgr.add_metrics(mask_metrics)
metrics_mgr.add_metric("time/get_sample_level_mask", get_sample_level_mask_timer.last)
- # 2. 处理reward相关策略
+ # 2. Apply reward-related strategies
with Timer(name="reward_postprocess", logger=None) as reward_postprocess_timer:
domain_batch, response_level_metrics = reward_postprocess(
domain_batch, self.pipeline_config, self.running
@@ -428,7 +428,7 @@ def run(self):
metrics_mgr.add_metrics(response_level_metrics)
metrics_mgr.add_metric("time/reward_postprocess", reward_postprocess_timer.last)
- # 3. 计算token level rewards
+ # 3. Calculate token level rewards
with Timer(name="get_token_reward", logger=None) as get_token_reward_timer:
domain_batch, token_level_metrics = compute_token_reward(
domain_batch, self.pipeline_config, self.kl_ctrl
@@ -436,7 +436,7 @@ def run(self):
metrics_mgr.add_metrics(token_level_metrics)
metrics_mgr.add_metric("time/get_token_reward", get_token_reward_timer.last)
- # 4. 计算advantage
+ # 4. Calculate advantage
final_response_mask = domain_batch.batch["final_response_mask"].clone()
with Timer(name="compute_advantage", logger=None) as compute_advantage_timer:
domain_batch = compute_advantage(
diff --git a/roll/third_party/deepspeed/offload_states_patch.py b/roll/third_party/deepspeed/offload_states_patch.py
index 62f32092..de38aa79 100644
--- a/roll/third_party/deepspeed/offload_states_patch.py
+++ b/roll/third_party/deepspeed/offload_states_patch.py
@@ -180,7 +180,7 @@ def stage_1_and_2_offload_states(self: DeepSpeedZeroOptimizer,
# LP param
if needs_offload(OffloadStateTypeEnum.lp_params, include, self.offloaded_states) and not self.bit16_groups_flat[
0].is_cpu:
- # NOTE: 这里只支持offload optimizer 里的参数部分
+ # NOTE: Only offloading the parameter portion inside the optimizer is supported here
if pin_memory:
if not hasattr(self, "lp_params_pin_buffers"):
self.lp_params_pin_buffers = [
@@ -200,7 +200,7 @@ def stage_1_and_2_offload_states(self: DeepSpeedZeroOptimizer,
self.offloaded_states.add(OffloadStateTypeEnum.lp_params)
# LP grad
- # NOTE: 这里好像没有 grad 缓存
+ # NOTE: There seems to be no grad cache here
if needs_offload(OffloadStateTypeEnum.lp_grads, include, self.offloaded_states):
pass
@@ -231,7 +231,7 @@ def stage_1_and_2_offload_states(self: DeepSpeedZeroOptimizer,
offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking)
self.offloaded_states.add(OffloadStateTypeEnum.optim_states)
- # NOTE: 清理额外引用,hp_mapping里包含了一份对全部flat tensor的引用
+ # NOTE: Clean up extra references, hp_mapping contains a reference to all flat tensors
for group in self.bit16_groups:
for param in group:
param._hp_mapping = None
@@ -286,7 +286,7 @@ def stage_1_and_2_reload_states(self: DeepSpeedZeroOptimizer, include=None, non_
reload_adam_states(self.optimizer, device, non_blocking=non_blocking)
self.offloaded_states.remove(OffloadStateTypeEnum.optim_states)
- # NOTE: 恢复link
+ # NOTE: Restore link
for group in self.bit16_groups:
for param in group:
param._hp_mapping = None
@@ -329,9 +329,9 @@ def stage_3_offload_states(self: DeepSpeedZeroOptimizer_Stage3,
if not hasattr(self, "lp_param_contiguous_pin_buffer"):
self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory(
torch.empty_like(self.lp_param_buffer, device=device))
- # NOTE: lp_param_buffer保存了由optimizer里取到的参数顺序
- # offload的时候先将 lp_param_buffer.cpu()
- # 然后将tensor.data cp给model 的tensor.data,这一步也会有顺序不一致问题
+ # NOTE: lp_param_buffer stores the parameters in the order retrieved from the optimizer.
+ # When offloading, first move lp_param_buffer to CPU with lp_param_buffer.cpu(),
+ # then copy tensor.data into the model's tensor.data; this step can also hit the order-inconsistency issue
self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking)
cpu_buffer = self.lp_param_contiguous_pin_buffer
else:
@@ -359,7 +359,7 @@ def stage_3_offload_states(self: DeepSpeedZeroOptimizer_Stage3,
else:
self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device)
self.averaged_gradients = {}
- # NOTE: self.__param_id_to_grad_partition里存了一份对grad_partitions_flat_buffer的引用,patch修改需要使用名称修饰
+ # NOTE: self.__param_id_to_grad_partition holds a reference to grad_partitions_flat_buffer; the patch must use the name-mangled attribute to modify it
setattr(self, "_DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition", {})
self.offloaded_states.add(OffloadStateTypeEnum.lp_grads)
@@ -398,8 +398,8 @@ def stage_3_reload_states(self: DeepSpeedZeroOptimizer_Stage3, include=None, non
self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking)
self._set_fp16_partitioned_groups_flat()
- # NOTE: 这里遍历的是self.module.parameters(), 而lp_param_buffer里的是fp16 group里取到的,这里参数的顺序不一致
- # 这里[p.ds_tensor for p in self.module.parameters()]需要按self.fp16_groups的顺序reorder一下
+ # NOTE: Here we iterate over self.module.parameters(), while lp_param_buffer holds parameters taken from the fp16 groups, so the parameter order differs
+ # [p.ds_tensor for p in self.module.parameters()] needs to be reordered to follow the order of self.fp16_groups
parameter_partitions: List[Tensor] = [param.ds_tensor for sub_group in self.fp16_groups for param in sub_group]
for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(parameter_partitions):
tensor.data = self.lp_param_buffer.data.narrow(0, offset, tensor_numel)
@@ -450,7 +450,7 @@ def parameter_offload_offload_states(self: DeepSpeedZeRoOffload,
self.offloaded_states = getattr(self, "offloaded_states", set())
if needs_offload(OffloadStateTypeEnum.lp_params, include, self.offloaded_states):
- # NOTE: 这里不会执行了non_trainable_params都在engine里处理了
+ # NOTE: This won't execute since non_trainable_params are all handled in the engine
if not hasattr(self, "trainable_params"):
self.trainable_params = [param.ds_tensor for param in self.module.parameters() if param.requires_grad]
if len(self.trainable_params) == 0:
diff --git a/roll/third_party/megatron/offload_states_patch.py b/roll/third_party/megatron/offload_states_patch.py
index e204be22..0ac50687 100644
--- a/roll/third_party/megatron/offload_states_patch.py
+++ b/roll/third_party/megatron/offload_states_patch.py
@@ -1,11 +1,11 @@
"""
-megatron offload states的实现思路:
+Megatron offload states implementation approach:
offload
-释放megatron.core.distributed.distributed_data_parallel.DistributedDataParallel中的buffer
-offload optimizer中的main_weights, main_weights.to('cpu'),使用flat tensor
-offload optimizer states, to('cpu')
-offload model weights, to('cpu'), 使用flat tensor;释放shard_float16_groups和shard_fp32_groups
+Release buffers in megatron.core.distributed.distributed_data_parallel.DistributedDataParallel
+Offload main_weights in optimizer, main_weights.to('cpu'), using flat tensor
+Offload optimizer states, to('cpu')
+Offload model weights, to('cpu'), using flat tensor; release shard_float16_groups and shard_fp32_groups
reload
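The offload/reload moves described in this docstring reduce to copying flat tensors between a pinned CPU buffer and the GPU; a generic sketch of that pattern (not Megatron's actual code) is:

import torch

class FlatBufferOffloader:
    def __init__(self, flat):
        self.flat = flat
        self.cpu_buf = torch.empty_like(flat, device="cpu").pin_memory()

    def offload(self, non_blocking=True):
        self.cpu_buf.copy_(self.flat, non_blocking=non_blocking)
        self.flat.data = self.cpu_buf  # rebind .data so the GPU copy can be freed

    def reload(self, device="cuda", non_blocking=True):
        self.flat.data = self.cpu_buf.to(device, non_blocking=non_blocking)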
@@ -221,10 +221,10 @@ def move_ddp_model_params_tensor_to_device(optimizer: DistributedOptimizer,
# Clone model -> main.
shard_model_param = model_param.detach().view(-1)[param_range.start: param_range.end]
- # shard_float16_groups 不属于optimizer state的key,可以直接替换param
- # 这种方式: optimizer.shard_float16_groups[
+ # shard_float16_groups is not a key in the optimizer state, so the param can be replaced directly
+ # This approach: optimizer.shard_float16_groups[
# group_index][len(shard_float16_params_this_group)].data = shard_model_param
- # 不能实现显存释放,定位到是model_param.detach()的影响,下面的fp32能正常释放
+ # Cannot release the GPU memory; this was traced to model_param.detach(), while the fp32 path below releases normally
optimizer.shard_float16_groups[group_index][
len(shard_float16_params_this_group)] = shard_model_param
shard_float16_params_this_group.append(shard_model_param)
@@ -252,7 +252,7 @@ def move_grad_data_to_device(optimizer,
# else:
# buffer.grad_data.data = buffer.grad_data.data.to(device, non_blocking=non_blocking)
- # 释放grad, 节省cpu memory
+ # Release grad, save CPU memory
if device == torch.device('cpu'):
buffer.grad_data.data = torch.tensor(1, dtype=buffer.grad_data.data.dtype, device=device, pin_memory=pin_memory)
for param in buffer.params[::-1]:
@@ -438,7 +438,7 @@ def offload_megatron_no_grad_module(model_chunks: List[Union[DistributedDataPara
non_blocking: bool = False
):
"""
- 需要offload一下 grad=False的参数
+ Need to offload parameters with grad=False
"""
device = torch.device('cpu')
diff --git a/roll/third_party/sglang/v043post4_patch/async_engine.py b/roll/third_party/sglang/v043post4_patch/async_engine.py
index 096b069e..e28de453 100644
--- a/roll/third_party/sglang/v043post4_patch/async_engine.py
+++ b/roll/third_party/sglang/v043post4_patch/async_engine.py
@@ -12,7 +12,7 @@ class SglangInputType(enum.Enum):
ABORT = enum.auto()
def list_endswith(lst, suffix):
- # 检查 lst 是否以 suffix 结尾
+ # Check if lst ends with suffix
return lst[-len(suffix):] == suffix if len(suffix) <= len(lst) else False
def trim_overlap_tokens(existing_tokens, new_chunk_tokens):
@@ -28,7 +28,7 @@ def trim_overlap_tokens(existing_tokens, new_chunk_tokens):
return new_chunk_tokens[max_overlap:]
-# 用于存放所有abort_rid_set
+# Stores all aborted request ids (rids)
abort_rid_set = set()
abort_lock = asyncio.Lock()
@@ -38,7 +38,7 @@ async def producer(thread_queue, asyncio_queue):
while True:
if not thread_queue.empty():
data = thread_queue.get()
- # 收到结束标记
+ # Received end marker
if data is None:
logger.info("[sglang async engine] receive stop signal, stoping")
break
@@ -57,12 +57,12 @@ async def consumer(asyncio_queue, consumer_id, llm, request_complete_callback):
from roll.distributed.scheduler.protocol import DataProto
def process_sglang_output(token_ids, meta_info):
- # 线上正式使用
+ # Online production use
output_data = DataProto(meta_info=meta_info)
output_data.meta_info["output_token_ids"] = token_ids
request_complete_callback(data=output_data)
- # 本地调试使用
+ # Local debugging use
# request_complete_callback(meta_info['request_id'], token_ids)
logger.debug(f"worker_id:{consumer_id} request_id: {meta_info['request_id']} finish!")
diff --git a/roll/third_party/sglang/v046post4_patch/async_engine.py b/roll/third_party/sglang/v046post4_patch/async_engine.py
index 096b069e..e28de453 100644
--- a/roll/third_party/sglang/v046post4_patch/async_engine.py
+++ b/roll/third_party/sglang/v046post4_patch/async_engine.py
@@ -12,7 +12,7 @@ class SglangInputType(enum.Enum):
ABORT = enum.auto()
def list_endswith(lst, suffix):
- # 检查 lst 是否以 suffix 结尾
+ # Check if lst ends with suffix
return lst[-len(suffix):] == suffix if len(suffix) <= len(lst) else False
def trim_overlap_tokens(existing_tokens, new_chunk_tokens):
@@ -28,7 +28,7 @@ def trim_overlap_tokens(existing_tokens, new_chunk_tokens):
return new_chunk_tokens[max_overlap:]
-# 用于存放所有abort_rid_set
+# Stores all aborted request ids (rids)
abort_rid_set = set()
abort_lock = asyncio.Lock()
@@ -38,7 +38,7 @@ async def producer(thread_queue, asyncio_queue):
while True:
if not thread_queue.empty():
data = thread_queue.get()
- # 收到结束标记
+ # Received end marker
if data is None:
logger.info("[sglang async engine] receive stop signal, stoping")
break
@@ -57,12 +57,12 @@ async def consumer(asyncio_queue, consumer_id, llm, request_complete_callback):
from roll.distributed.scheduler.protocol import DataProto
def process_sglang_output(token_ids, meta_info):
- # 线上正式使用
+ # Online production use
output_data = DataProto(meta_info=meta_info)
output_data.meta_info["output_token_ids"] = token_ids
request_complete_callback(data=output_data)
- # 本地调试使用
+ # Local debugging use
# request_complete_callback(meta_info['request_id'], token_ids)
logger.debug(f"worker_id:{consumer_id} request_id: {meta_info['request_id']} finish!")
diff --git a/roll/third_party/vllm/vllm_0_7_3/llm.py b/roll/third_party/vllm/vllm_0_7_3/llm.py
index 1c99ab70..0b4c349e 100644
--- a/roll/third_party/vllm/vllm_0_7_3/llm.py
+++ b/roll/third_party/vllm/vllm_0_7_3/llm.py
@@ -147,7 +147,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
def clear_unfinished_requests(self):
self._run_engine(use_tqdm=True)
- # 参数同步接口
+ # Parameter synchronization interface
def setup_collective_group(self, *args, **kwargs):
self.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs)
diff --git a/roll/third_party/vllm/vllm_0_7_3/llm_engine.py b/roll/third_party/vllm/vllm_0_7_3/llm_engine.py
index ded9ec38..63fa5502 100644
--- a/roll/third_party/vllm/vllm_0_7_3/llm_engine.py
+++ b/roll/third_party/vllm/vllm_0_7_3/llm_engine.py
@@ -61,7 +61,7 @@ def update_worker_cls_config(cls, vllm_config: VllmConfig) -> None:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
- # TODO: 投机采样
+ # TODO: Speculative sampling
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
@@ -72,7 +72,7 @@ def update_worker_cls_config(cls, vllm_config: VllmConfig) -> None:
"vllm.worker.worker.Worker"
else:
if envs.VLLM_USE_V1:
- # TODO: 实现v1
+ # TODO: Implement v1
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
diff --git a/roll/third_party/vllm/vllm_0_8_4/llm.py b/roll/third_party/vllm/vllm_0_8_4/llm.py
index c43bc829..eab40b60 100644
--- a/roll/third_party/vllm/vllm_0_8_4/llm.py
+++ b/roll/third_party/vllm/vllm_0_8_4/llm.py
@@ -166,7 +166,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
def clear_unfinished_requests(self):
self._run_engine(use_tqdm=True)
- # 参数同步接口
+ # Parameter synchronization interface
def setup_collective_group(self, *args, **kwargs):
self.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs)
diff --git a/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py b/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py
index e063da92..72f3d856 100644
--- a/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py
+++ b/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py
@@ -74,7 +74,7 @@ def offload_states(self, level=2):
self.reset_prefix_cache()
self.collective_rpc(method="offload_states")
- # 参数同步接口
+ # Parameter synchronization interface
def setup_collective_group(self, *args, **kwargs):
self.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs)
diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py
index e95a17cd..46130a11 100644
--- a/roll/utils/functionals.py
+++ b/roll/utils/functionals.py
@@ -40,8 +40,8 @@ def delete_tensor_grad_visitor(obj, path):
def traverse_obj(value, visitor, path=()):
"""
- 遍历对象的所有属性,包括属性的属性,找到所有的 Tensor。
- :param value: 任意 Python 对象
+ Traverse all attributes of an object, including attributes of attributes, to find all Tensors.
+ :param value: Any Python object
:visitor
:path
"""
@@ -86,7 +86,7 @@ def divide_by_chunk_size(
data: Union[np.ndarray, TensorDict], chunk_sizes: List[int]
) -> List[Union[np.ndarray, TensorDict]]:
"""
- 将numpy数组按照chunks的大小切分
+ Split numpy array according to chunk sizes
"""
if not isinstance(data, (np.ndarray, TensorDict)):
raise TypeError("Input 'array' must be a numpy ndarray or a TensorDict.")
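For the numpy case, the split-by-chunk-sizes behavior can be sketched as follows (assumed semantics; the real function also handles TensorDict):

import numpy as np

def split_by_chunk_sizes(arr, chunk_sizes):
    assert sum(chunk_sizes) == len(arr), "chunk sizes must cover the array exactly"
    return np.split(arr, np.cumsum(chunk_sizes)[:-1])

# split_by_chunk_sizes(np.arange(6), [2, 1, 3]) -> [array([0, 1]), array([2]), array([3, 4, 5])]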
@@ -314,7 +314,7 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T
def response_level_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True):
"""Whiten values with masked values."""
- # 考虑response的影响?
+ # Consider the impact of response?
mean = masked_mean(values, mask, dim=-1)
var = masked_var(mean, mask)
mean = mean.mean()
@@ -475,7 +475,7 @@ def compute_token_reward(data: "DataProto", pipeline_config: RLVRConfig, kl_ctrl
action_mask=data.batch["response_mask"][:, 1:],
kl_penalty=pipeline_config.kl_penalty,
)
- # 是否添加token level kl
+ # Whether to add token level kl
if pipeline_config.add_token_level_kl and "ref_log_probs" in data.batch.keys():
beta = kl_ctrl.value
token_level_rewards = token_level_rewards - beta * kld
@@ -505,8 +505,8 @@ def compute_token_reward(data: "DataProto", pipeline_config: RLVRConfig, kl_ctrl
def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_ctrl):
response_level_rewards = data.batch["response_level_rewards"].clone().detach()
response_level_metrics = {"critic/reward_clip_frac": 0.0}
- # 对reward进行处理: 可以选择不同的normalization方法
- # 使用group-based normalization (按prompt分组)
+ # Process rewards: can choose different normalization methods
+ # Use group-based normalization (group by prompt)
if pipeline_config.adv_estimator == "grpo" or pipeline_config.reward_norm == "group":
if pipeline_config.reward_shift:
response_level_rewards = group_reward_norm(
@@ -521,14 +521,14 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c
div_std=True,
)
- # 使用batch-based normalization (整个batch)
+ # Use batch-based normalization (entire batch)
elif pipeline_config.reward_norm == "batch":
if hasattr(pipeline_config, "reward_shift") and pipeline_config.reward_shift:
response_level_rewards = batch_reward_norm(response_level_rewards, div_std=False)
else:
response_level_rewards = batch_reward_norm(response_level_rewards, div_std=True)
- # 使用running statistics进行normalization
+ # Use running statistics for normalization
elif pipeline_config.reward_norm == "running":
running = running_ctrl["domain"]
running.update(response_level_rewards)
@@ -541,7 +541,7 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c
else:
response_level_rewards = (response_level_rewards - mean) / std
- # 对reward进行clip
+ # Clip reward
if pipeline_config.reward_clip:
reward_clip_frac = compute_clip_fraction(
values=response_level_rewards, clip_max=pipeline_config.reward_clip, clip_min=-pipeline_config.reward_clip
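The group / batch / running normalization branches above can be sketched roughly as below; the helper names are illustrative, not the repo's exact signatures:

import torch

def batch_reward_norm_sketch(r, div_std=True):
    r = r - r.mean()
    return r / (r.std() + 1e-6) if div_std else r

def group_reward_norm_sketch(r, prompt_id, div_std=True):
    out = torch.zeros_like(r)
    for pid in prompt_id.unique():  # one group per prompt (num_return_sequences responses)
        m = prompt_id == pid
        g = r[m] - r[m].mean()
        out[m] = g / (g.std() + 1e-6) if div_std else g
    return out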
@@ -561,7 +561,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig):
batch_size = data.batch["response_mask"].size(0)
mask_metrics = {}
- # mask相关策略
+ # Mask related strategies
data.batch["origin_response_mask"] = data.batch["response_mask"].clone()
response_mask = data.batch["response_mask"][:, 1:].clone()
true_response_length = response_mask.sum(-1).float()
@@ -569,7 +569,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig):
final_sample_mask = torch.ones(batch_size, device=response_mask.device)
- # 1. max_len_mask: 过滤掉超过最大长度的样本
+ # 1. max_len_mask: Filter out samples exceeding maximum length
if pipeline_config.max_len_mask:
max_len_mask = (max_response_length != true_response_length).float()
final_sample_mask = final_sample_mask * max_len_mask
@@ -577,7 +577,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig):
else:
mask_metrics["actor/max_len_mask_ratio"] = 1.0
- # 2. difficulty_mask: 基于难度的过滤
+ # 2. difficulty_mask: Filter based on difficulty
if pipeline_config.difficulty_mask:
data = difficulty_mask(
data,
@@ -594,7 +594,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig):
else:
mask_metrics["actor/difficulty_mask_ratio"] = 1.0
- # 3. error_max_len_clip: 基于错误和长度的过滤
+ # 3. error_max_len_clip: Filter based on errors and length
if pipeline_config.error_max_len_clip:
scores = data.batch["scores"]
error_len_mask = ((scores == 0) & (true_response_length < pipeline_config.error_max_len_threshold)) | (
@@ -692,7 +692,7 @@ def compute_advantage(
data.batch["raw_advantages"] = advantages
if whiten_advantages:
- # TODO whiten过程中是否要考虑response的长度?
+ # TODO Should response length be considered during whitening process?
advantages = masked_whiten(values=advantages, mask=response_mask)
advantages = advantages * response_mask
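A compact sketch of masked whitening as used here, assuming the usual semantics (statistics computed only over unmasked response tokens):

import torch

def masked_whiten_sketch(values, mask, eps=1e-8):
    mean = (values * mask).sum() / mask.sum()
    var = ((values - mean) ** 2 * mask).sum() / mask.sum()
    return (values - mean) * torch.rsqrt(var + eps)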
@@ -725,8 +725,8 @@ def postprocess_generate(
from roll.distributed.scheduler.protocol import DataProto
if fill_eos_token:
- # yali: 如果output最后一个token不是pad_token_id,则替换成eos_token_id,
- # TODO: 需要消融这个变化的影响
+ # yali: If the last token of output is not pad_token_id, replace it with eos_token_id,
+ # TODO: Need to ablate the impact of this change
last_token_index = output.size(1) - 1
need_replace_mask = output[:, last_token_index] != pad_token_id
output[need_replace_mask, last_token_index] = eos_token_id
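A tiny worked example of the fill_eos_token replacement above, with illustrative token ids:

import torch

pad_token_id, eos_token_id = 0, 2
output = torch.tensor([[5, 6, 7],   # truncated: last token is neither pad nor eos
                       [5, 6, 0]])  # already padded
need_replace_mask = output[:, -1] != pad_token_id
output[need_replace_mask, -1] = eos_token_id
# output -> [[5, 6, 2], [5, 6, 0]]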
diff --git a/tests/distributed/strategy/generate/generate_pipeline.py b/tests/distributed/strategy/generate/generate_pipeline.py
index 2b7e0283..f7191f34 100644
--- a/tests/distributed/strategy/generate/generate_pipeline.py
+++ b/tests/distributed/strategy/generate/generate_pipeline.py
@@ -62,7 +62,7 @@ def __init__(self, pipeline_config: RLVRConfig):
def run(self):
global_step = 0
- # 计算tokens per second 系统吞吐
+ # Compute system throughput in tokens per second
tps_timer = _Timer(window_size=5)
metric_list = []
@@ -161,7 +161,7 @@ def __init__(self, pipeline_config: RLVRConfig):
def run(self):
global_step = 0
- # 计算tokens per second 系统吞吐
+ # Compute system throughput in tokens per second
tps_timer = _Timer(window_size=5)
metric_list = []
diff --git a/tests/math/test_math_dataset.py b/tests/math/test_math_dataset.py
index a45d49dc..e413f87a 100644
--- a/tests/math/test_math_dataset.py
+++ b/tests/math/test_math_dataset.py
@@ -12,7 +12,7 @@
dataset = load_dataset("json", data_files=dataset_path)["train"]
-# 加上format,然后转ids的func
+# Func that applies the format and then converts to ids
def encode_function(data_i):
text_list = []
for instruct in data_i["prompt"]:
@@ -26,11 +26,11 @@ def encode_function(data_i):
return encodings
-# 处理数据
+# Process data
print(dataset)
dataset = dataset.map(encode_function, batched=True, desc="Encoding dataset")
print(dataset)
-# 过滤cutoff
+# Filter by the cutoff length
dataset = dataset.filter(lambda data_i: len(data_i["input_ids"]) <= 512, desc="Filtering dataset")
print(dataset)
# ------
diff --git a/tests/models/cuda_mem/test_large_gemm.py b/tests/models/cuda_mem/test_large_gemm.py
index 64421a67..a1c2a5df 100644
--- a/tests/models/cuda_mem/test_large_gemm.py
+++ b/tests/models/cuda_mem/test_large_gemm.py
@@ -1,34 +1,34 @@
import torch
import torch.nn as nn
-# 定义参数
-vocab_size = 128 * 1000 # 词汇表大小
-intermediate_size = 1560 # 中间隐藏层大小
+# Define parameters
+vocab_size = 128 * 1000 # Vocabulary size
+intermediate_size = 1560 # Intermediate hidden layer size
batch_size = 1
-seq_len = 4096 # 假设序列长度
+seq_len = 4096 # Assumed sequence length
-# 创建一个随机隐层输出张量 (batch_size, seq_len, intermediate_size)
+# Create a random hidden layer output tensor (batch_size, seq_len, intermediate_size)
hidden_output = torch.randn(batch_size, seq_len, intermediate_size).cuda()
-# 创建最后一层线性层,将隐层大小映射到词汇表大小
+# Create the last linear layer, mapping hidden layer size to vocabulary size
# linear_layer = nn.Linear(intermediate_size, vocab_size).cuda()
-# 对每个时间步进行最后一层的计算,可以使用 reshape
-# 如果需要按照 seq_len 进行计算,可以使用 batched 为 (batch_size * seq_len, intermediate_size)
-# logits = linear_layer(hidden_output.view(-1, intermediate_size)) # 变形为 (batch_size * seq_len, intermediate_size)
+# Compute the final layer for every time step; reshape can be used
+# To compute across seq_len, batch the input as (batch_size * seq_len, intermediate_size)
+# logits = linear_layer(hidden_output.view(-1, intermediate_size)) # Reshape to (batch_size * seq_len, intermediate_size)
-# 直接构造权重矩阵 W (intermediate_size, vocab_size)
+# Directly construct weight matrix W (intermediate_size, vocab_size)
weight_matrix = torch.randn(intermediate_size, vocab_size).cuda()
-# 计算 logits,进行矩阵乘法
-# 对每个时间步进行最后一层的计算,可以使用 reshape
+# Compute logits via matrix multiplication
+# (the final layer applied to every time step; reshape can be used)
logits = torch.matmul(hidden_output.view(-1, intermediate_size), weight_matrix)
-# 重新调整 logits 的形状为 (batch_size, seq_len, vocab_size)
+# Reshape logits to (batch_size, seq_len, vocab_size)
logits = logits.view(batch_size, seq_len, vocab_size)
-# 计算 softmax 以得到概率分布
-probabilities = nn.functional.softmax(logits, dim=-1) # 计算每个时间步的概率分布
+# Compute softmax to get probability distribution
+probabilities = nn.functional.softmax(logits, dim=-1) # Compute probability distribution for each time step
del logits, probabilities, weight_matrix, hidden_output
torch.cuda.empty_cache()
diff --git a/tests/third_party/megatron/test_offload_states.py b/tests/third_party/megatron/test_offload_states.py
index fa148669..35b56055 100644
--- a/tests/third_party/megatron/test_offload_states.py
+++ b/tests/third_party/megatron/test_offload_states.py
@@ -210,7 +210,7 @@ def create_mca_model(self):
"""
torchrun --standalone --nnodes=1 --nproc-per-node=2 -m pytest -s tests/third_party/megatron/test_offload_states.py
--s 显示stdout/err
+-s shows stdout/err
"""
@@ -881,7 +881,7 @@ def run_model_fp32_optimizer(mca_model: TurboModelCreator, included_state, pin_m
@pytest.mark.parametrize("optimizer_type", [None, "dist_optimizer", "fp16", "fp32"])
def test_megatron_offload_states(included_state, pin_memory, non_blocking, optimizer_type):
"""
- 有四块非optimizer的显存未释放:
+ There are four blocks of non-optimizer GPU memory not released:
/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/transformer_engine/pytorch/module/base.py:58:get_workspace
/root/.local/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py:413:forward
/root/.local/lib/python3.10/site-packages/megatron/core/models/gpt/gpt_model.py:249:forward