From 1e18794da1b32704a3767aced0fc1b53f7f9890c Mon Sep 17 00:00:00 2001 From: Jennifer Zhou Date: Thu, 12 Jun 2025 22:30:11 +0000 Subject: [PATCH] Translate to english --- .../scheduler/generate_scheduler.py | 42 +-- .../distributed/scheduler/reward_scheduler.py | 22 +- roll/models/model_providers.py | 4 +- roll/pipeline/agentic/agentic_config.py | 4 +- roll/pipeline/agentic/agentic_pipeline.py | 12 +- roll/pipeline/agentic/environment_worker.py | 56 ++-- roll/pipeline/base_worker.py | 28 +- roll/pipeline/rlvr/actor_worker.py | 16 +- .../rewards/code_sandbox_reward_worker.py | 24 +- .../crossthinkqa_rule_reward_worker.py | 38 +-- .../rewards/general_val_rule_reward_worker.py | 36 +-- .../rlvr/rewards/ifeval_rule_reward_worker.py | 246 +++++++++--------- .../rlvr/rewards/llm_judge_reward_worker.py | 2 +- .../rlvr/rewards/math_rule_reward_worker.py | 10 +- roll/pipeline/rlvr/rlvr_config.py | 2 +- roll/pipeline/rlvr/rlvr_pipeline.py | 26 +- .../deepspeed/offload_states_patch.py | 22 +- .../megatron/offload_states_patch.py | 20 +- .../sglang/v043post4_patch/async_engine.py | 10 +- .../sglang/v046post4_patch/async_engine.py | 10 +- roll/third_party/vllm/vllm_0_7_3/llm.py | 2 +- .../third_party/vllm/vllm_0_7_3/llm_engine.py | 4 +- roll/third_party/vllm/vllm_0_8_4/llm.py | 2 +- .../vllm/vllm_0_8_4/v1/async_llm.py | 2 +- roll/utils/functionals.py | 34 +-- .../strategy/generate/generate_pipeline.py | 4 +- tests/math/test_math_dataset.py | 6 +- tests/models/cuda_mem/test_large_gemm.py | 30 +-- .../megatron/test_offload_states.py | 4 +- 29 files changed, 359 insertions(+), 359 deletions(-) diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index 9e8d3bd3..0483f2d3 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -171,7 +171,7 @@ def generate(self, data: DataProto, actor_cluster: Union[Any, Cluster], pipeline def get_available_dp_rank(self): while True: - # 负载均衡逻辑,期望各dp 正在处理的条数基本接近 + # Load balancing logic, expect the number of items being processed by each dp to be roughly similar sorted_ranks = sorted( self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank) ) @@ -205,26 +205,26 @@ def generate_opt_level_1(self, data: DataProto): ) self.cluster.start_server(data=DataProto(meta_info=data.meta_info), blocking=True) - # 分发数据至收到target rollout 完成 - # 无限循环,把所有的response发送给dp worker + # Distribute data until target rollout completion + # Infinite loop, send all responses to dp workers send_request_count = 0 request_refs = [] data_index_counter = itertools.count() last_alive_check = time.time() while not self.is_completed: - # 探测dp worker是否存活,dp worker的server thread可能由于异常退出,造成hang + # Check if dp worker is alive, dp worker's server thread may exit due to exceptions, causing hang current_time = time.time() if current_time - last_alive_check >= self.alive_check_interval: self.cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, data=DataProto()) last_alive_check = current_time if send_request_count < data.batch.batch_size[0]: - # 取一个可以发送request的dp worker + # Get a dp worker that can send requests dp_rank = next(self.get_available_dp_rank()) - # 还有数据需要发送, 取需要发送的数据 - # request_id 全局递增,否则vllm/sglang scheduler状态不对 + # Still have data to send, get the data that needs to be sent + # request_id increments globally, otherwise vllm/sglang scheduler state is incorrect request_id = next(self.request_counter) data_index = 
next(data_index_counter) request_data = collate_fn([self.data[data_index]]) @@ -235,7 +235,7 @@ def generate_opt_level_1(self, data: DataProto): ].item() self.request_id_2_dp_rank[request_data.meta_info["request_id"]] = dp_rank self.prompt_id_2_request_ids[prompt_id].add(request_data.meta_info["request_id"]) - # 需要注意上面的调用顺序, report_response中会更新request_id索引dp_rank,所以这里需要最后add request_id + # Need to pay attention to the calling order above, report_response will update request_id index dp_rank, so need to add request_id last request_data.meta_info["response_callback_fn"] = self.response_callback_fn request_data.meta_info["generation_config"] = data.meta_info["generation_config"] request_refs.append( @@ -394,7 +394,7 @@ def set_scheduler( state: Dict[str, Any] = None, ): """ - GenerateScheduler可以由多个实例,不再局限于单例 + GenerateScheduler can have multiple instances, no longer limited to singleton """ self.actor_cluster = actor_cluster self.reward_clusters = reward_clusters @@ -459,9 +459,9 @@ def reset_status(self): def get_batch(self, data: DataProto, batch_size: int) -> DataProto: """ - 从dataset里,按给定策略sample batch - 1. 常规无过滤 - 2. 动态过滤 + Sample batch from dataset using given strategy + 1. Regular without filtering + 2. Dynamic filtering """ self.batch_size = batch_size self.reset_status() @@ -522,7 +522,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: f"used queries: {query_use_count} query_filter_count: {self.query_filter_count} " f"response_filter_count: {self.response_filter_count}" ) - # TODO: 这里 len(collect_data) > rollout_batch_size, 可以尝试动态扩大batch_size + # TODO: Here len(collect_data) > rollout_batch_size, can try dynamically expanding batch_size batch = DataProto.concat(collect_data[: self.batch_size * num_return_sequences]) batch.meta_info["metrics"] = { f"scheduler/query_filter_count": self.query_filter_count, @@ -531,7 +531,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: f"scheduler/query_use_count": query_use_count, } - # 统计全部response metrics + # Count all response metrics metrics = {} for domain, response_batches in self.response_cache.items(): response_batch = DataProto.concat(response_batches[:]) @@ -548,8 +548,8 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: @ray.method(concurrency_group="multi_thread") def report_response(self, data: DataProto): """ - 这里需要考虑多线程数据访问 - data 返回可能有多条的 + Need to consider multi-threaded data access here + Data return may have multiple entries """ try: request_id = data.meta_info["request_id"] @@ -570,15 +570,15 @@ def report_response(self, data: DataProto): return # call reward - # reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意... - # 多域的时候,llm as judge, 需要单独为reward worker分配gpu + # reward worker must support single data calculation, dynamic sampling needs attention for batch reward calculation... 
+ # In multi-domain cases (e.g. LLM-as-judge), GPUs need to be allocated separately for the reward workers rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch)) batch.union(rewards) response_buffers: List[DataProto] = [] batch_expanded = [batch[[idx]] for idx in range(output_count)] - # response_filter, 不太需要response filter + # response_filter; a response filter is not really needed here for batch_item in batch_expanded: if self.response_filter_fn(batch_item, self.pipeline_config): response_buffers.append(batch_item) @@ -706,7 +706,7 @@ def expand_requests(self, data: DataProto): return target_requests def check_worker_alive(self, cluster): - # 探测dp worker是否存活,dp worker的server thread可能由于异常退出,造成hang + # Check if dp worker is alive, dp worker's server thread may exit due to exceptions, causing hang current_time = time.time() if current_time - self.last_alive_check >= self.alive_check_interval: cluster.add_request(command=GenerateRequestType.ALIVE_CHECK, data=DataProto()) @@ -727,7 +727,7 @@ def check_send_new_request(self) -> bool: def get_available_dp_rank(self): while True: - # 负载均衡逻辑,期望各dp 正在处理的条数基本接近 + # Load balancing logic, expect the number of items being processed by each dp to be roughly similar sorted_ranks = sorted( self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank) ) diff --git a/roll/distributed/scheduler/reward_scheduler.py b/roll/distributed/scheduler/reward_scheduler.py index 120bb78d..8703ecd9 100644 --- a/roll/distributed/scheduler/reward_scheduler.py +++ b/roll/distributed/scheduler/reward_scheduler.py @@ -15,14 +15,14 @@ @ray.remote class RewardScheduler: """ - reward 服务化和generate不同, request接口: - reward scheduler需要解决的是不同域的sample的reward计算问题, 不需要实现request粒度的接口; - 并且reward计算和vllm不同,vllm可以continue batch,所以可以动态add request, reward不行, - 直接rpc调用reward_cluster.compute_rewards即可(使用rpc方式调用,可以增加reward的数量,增大并发处理能力) + Reward serving differs from generation in its request interface: + the reward scheduler only needs to compute rewards for samples from different domains, so it does not need a request-granularity interface; + and unlike vllm, which supports continuous batching and can add requests dynamically, reward computation cannot, + so directly calling reward_cluster.compute_rewards via rpc is sufficient (the rpc approach makes it easy to add more reward workers and increase concurrency) - reward scheduler需要解决的问题: - 按domain路由reward - dp dispatch 均分/不足dp_size 的限制 + Problems the reward scheduler needs to solve: + Routing rewards by domain + dp dispatch: even splitting, and handling batches smaller than dp_size """ def __init__(self): @@ -32,13 +32,13 @@ def __init__(self): def compute_rewards(self, data: DataProto, reward_clusters: Dict[str, Any], pipeline_config) -> DataProto: """ - 保序返回rewards + Return rewards in the original input order """ self.pipeline_config = pipeline_config self.reward_clusters = reward_clusters data.batch["prompt_id"] = torch.arange(data.batch.batch_size[0], device=data.batch.device) - # 按domain group by data + # Group data by domain grouped_data: Dict[str, DataProto] = data.group_by("domain") domain_rewards_refs: Dict[str, List[ray.ObjectRef]] = defaultdict(list) @@ -51,8 +51,8 @@ def compute_rewards(self, data: DataProto, reward_clusters: Dict[str, Any], pipe rewards_list: List[DataProto] = [] for domain, domain_rewards_ref in domain_rewards_refs.items(): - # 各reward的输出schema要求一致 - # reward worker compute_rewards 接口返回结果保序 + # All reward workers must produce a consistent output schema + # The reward worker compute_rewards interface returns results in order
if domain not in grouped_data.keys(): continue domain_rewards: DataProto = DataProto.materialize_concat(data_refs=domain_rewards_ref) diff --git a/roll/models/model_providers.py b/roll/models/model_providers.py index 2c147ed5..ed3ffe54 100644 --- a/roll/models/model_providers.py +++ b/roll/models/model_providers.py @@ -376,9 +376,9 @@ def default_reward_model_provider( is_trainable: Optional[bool] = False, ): """ - model.forward 遵循TokenClassifierOutput 协议 + model.forward follows TokenClassifierOutput protocol class TokenClassifierOutput(ModelOutput): - logits: torch.FloatTensor # 必须要有 + logits: torch.FloatTensor # Required loss: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index 10999450..e49eb140 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -45,7 +45,7 @@ class EnvManagerConfig(WorkerConfig): def __post_init__(self): """ - 根据es config计算world_size + Calculate world_size based on es config """ self.world_size = self.env_groups * self.group_size self.env_configs: Optional[Dict[int, Dict]] = None @@ -266,7 +266,7 @@ def set_max_steps(self, max_steps: int): self.critic.training_args.per_device_train_batch_size * self.critic.training_args.gradient_accumulation_steps ) - # 没有除dp_size,需要在分布式环境初始化后再除 + # Not divided by dp_size, need to divide after distributed environment initialization self.actor_train.training_args.max_steps = max_steps * ( self.rollout_batch_size * self.actor_infer.generating_args.num_return_sequences diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 54ea5c4f..3b991fba 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -107,7 +107,7 @@ def __init__(self, pipeline_config: AgenticConfig): @torch.no_grad() def run(self): - # 计算tokens per second 系统吞吐 + # Calculate tokens per second system throughput tps_timer = _Timer(window_size=5) for global_step in range(self.pipeline_config.max_steps): @@ -191,8 +191,8 @@ def run(self): metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) metrics["time/old_log_probs_values"] = cal_old_logpb_timer.last - # 要按group by处理reward - # 可以tag(env_type)/traj_group_id(group)/batch(rollout_batch)... group_by计算reward/adv + # Need to process rewards by group + # Can group by tag(env_type)/traj_group_id(group)/batch(rollout_batch)... to calculate reward/adv batch.batch["prompt_id"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device) with Timer(name="adv", logger=None) as timer: grouping = self.pipeline_config.reward_normalization.grouping @@ -228,7 +228,7 @@ def run(self): batch = DataProto.concat(batch_list) batch.reorder(indices=torch.argsort(batch.batch["prompt_id"])) batch.pop("prompt_id") - # advantage是全局batch计算,还是group内计算? + # Is advantage calculated globally across batch or within groups? 
batch = compute_advantage( data=batch, gamma=self.pipeline_config.gamma, @@ -314,8 +314,8 @@ def run(self): def compute_data_metrics(batch): - # token_level_scores 是reward model给每个token的打分,可能经过了norm/clip - # score 为env的reward,raw value + # token_level_scores are scores given by reward model to each token, possibly after norm/clip + # score is the environment reward, raw value sequence_score = batch.batch["scores"].sum(-1) sequence_reward = batch.batch["token_level_rewards"].sum(-1) advantages = batch.batch["advantages"] diff --git a/roll/pipeline/agentic/environment_worker.py b/roll/pipeline/agentic/environment_worker.py index 80493b35..6576eb11 100644 --- a/roll/pipeline/agentic/environment_worker.py +++ b/roll/pipeline/agentic/environment_worker.py @@ -51,7 +51,7 @@ def get_masks_and_scores( ): """ input_ids: shape (bsz, seq_len) - all_scores: list[list[float], 存储每个env每轮的reward + all_scores: list[list[float]], stores reward for each env per turn Get loss mask that only learns between <|im_start|>assistant and <|im_end|>. Currently only supports qwen. NOTE: important! This assumes that the input_ids starts with system and then user & assistant in alternative ways NOTE: important! input_ids is left pad @@ -65,7 +65,7 @@ def get_masks_and_scores( non_prompt_mask = turn_indicators > 2 # learns everything after system prompt + user prompts # turn text: '<|im_start|>assistant\nRight<|im_end|>' - # <|im_start|>assistant\n 应该mask掉才对,保留<|im_end|> + # <|im_start|>assistant\n should be masked, keep <|im_end|> for idx, scores in enumerate(zip_longest(*all_scores, fillvalue=0)): turn_indicator = idx * 2 + 3 # 0: pad. 1: system. 2+2n: user. 3+2n: assistant turn_start_position = (input_ids == turn_start_token) & (turn_indicators == turn_indicator) @@ -129,19 +129,19 @@ def left_pad_2_right( class EnvironmentWorker(Worker): """ - 1. 一个EnvironmentWorker(进程)持有一个env实例: 执行env.reset, env.step, 管理rollout的状态 - group trajectory表达: group内的init state一致,依赖env_config 中的seed来控制, 一个group内env 对应episode的seed一致 - 不采用持有envs的原因是,envs需要管理一组env的交互,增加描述的复杂性 - 2. 持有infer_cluster ref, 用于async generate - 3. run_rollout_loop, 持续rollout trajectory, 将done的trajectory回传到output_queue - - 承担EnvStateManager的history收集功能 - 一个group内的env reset进度应该一致 - - TODO: env并行方式后续改成进程+线程并行:目的解决一个env占用一个进程对系统资源的开销 - - 一个EnvironmentWorker持有n个EnvStateManager - - EnvStateManager管理一个env的rollout loop - - EnvStateManager.run_rollout_loop,运行在n个线程里 + 1. One EnvironmentWorker (process) holds an env instance: executes env.reset, env.step, manages rollout state + group trajectory representation: init state within group is consistent, controlled by seed in env_config, env within a group corresponds to consistent episode seed + reason for not holding envs: envs need to manage interaction of a group of envs, increasing complexity of description + 2. Holds infer_cluster ref, used for async generate + 3. 
run_rollout_loop continuously rolls out trajectories and sends finished trajectories back to the output_queue + + Also takes over EnvStateManager's history-collection responsibility + Env reset progress within a group should stay in sync + + TODO: env parallelism will later be changed to process+thread parallelism, to avoid the system-resource overhead of one process per env + - One EnvironmentWorker holds n EnvStateManagers + - EnvStateManager manages one env's rollout loop + - EnvStateManager.run_rollout_loop runs in n threads TODO: GiGPO: https://arxiv.org/abs/2505.10978 """ @@ -276,7 +276,7 @@ def generate(self, env_output: Dict): lm_output: DataProto = ray.get(self.generate_scheduler.generate_one_request.remote(data=gen_batch)) if lm_output is not None: - # 未被abort + # not aborted gen_batch.meta_info.pop("generation_config") gen_batch.meta_info.pop("response_callback_fn") lm_input = lm_input.repeat(repeat_times=generation_config["num_return_sequences"]) @@ -286,11 +286,11 @@ def generate(self, env_output: Dict): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def run_rollout_loop(self, data: DataProto): """ - 1. 每次调用run_rollout_loop, - 会持续的play episode, 直到收到采集完成的command - 需要重置seed, 确保每个group的seed一致 - episode_id 置0 - seed更新逻辑: + 1. Each call to run_rollout_loop + keeps playing episodes until the collection-complete command is received; + the seed must be reset so that seeds stay consistent within each group, + and episode_id is reset to 0. + Seed update logic: group_seed = seed + group_seed episode_seed = group_seed + episode_id @@ -379,9 +379,9 @@ def get_env_input(self, lm_output: DataProto) -> Dict: def formulate_rollouts(self): """ - 1. 每个env的trajectory 是一个rollout - 2. 每个rollout 是一个List[Dict] - 3. 每个Dict 是一个step的信息 + 1. Each env's trajectory is a rollout + 2. Each rollout is a List[Dict] + 3. Each Dict holds the information for one step """ llm_input_texts, messages_list = self._format_messages( env_output=self.rollout_cache, prepare_for_update=True, use_raw_llm_response=False ) @@ -436,7 +436,7 @@ def formulate_rollouts(self): llm_inputs.batch["non_prompt_mask"] = non_prompt_mask llm_inputs.batch["response_mask"] = non_prompt_mask if self.pipeline_config.enable_response_mask: - # 只使用llm的response mask,不包含环境的state + # use only the llm's response mask, excluding the environment's state llm_inputs.batch["response_mask"] = response_mask first_true_indices = non_prompt_mask.int().argmax(dim=1) no_true_mask = ~non_prompt_mask.any(dim=1) @@ -601,7 +601,7 @@ def _format_messages(self, env_output: Dict, prepare_for_update: bool, use_raw_l ) if "llm_raw_response" in content: # yali: using the raw response will cause continuous crashes: https://aliyuque.antfin.com/mdl-team/traning/wmne4oyxg4dozwia - # 改成actions合理吗? + # would it be reasonable to switch to actions? 
messages.append( { "role": "assistant", @@ -627,12 +627,12 @@ def _format_messages(self, env_output: Dict, prepare_for_update: bool, use_raw_l else: text += "" # force the LLM to answer - # TODO: 应该没有必要,注意处理mask + # TODO: should not be necessary, pay attention to handling mask text = text.replace("<|im_end|>\n", "<|im_end|>") return [text], [messages] def _init_prefix_lookup(self): - # TODO: 这里并不合理 + # TODO: this is not reasonable here prefix_lookup = {} prefixes = {} env_config_lookup = {} diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 94106271..9fca480d 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -163,7 +163,7 @@ def generate(self, data: DataProto): @torch.no_grad() def start_server(self, data: DataProto): """ - 解决dp generate的长尾问题,async+ load balance + Solve dp generate long tail problem, async + load balance """ global_step = data.meta_info.get("global_step", 0) is_offload_states = data.meta_info.get("is_offload_states", True) @@ -235,9 +235,9 @@ def compute_log_probs(self, data: DataProto): def forward_func_log_probs(self, data: DataProto, output_tensor: torch.Tensor): """ - forward func 接口定义: - data: DataProto, 由forward_step透传 - output_tensor: torch.Tensor, model.forward()的输出Tensor + forward func interface definition: + data: DataProto, passed through from forward_step + output_tensor: torch.Tensor, output tensor from model.forward() """ log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] @@ -247,9 +247,9 @@ def forward_func_log_probs(self, data: DataProto, output_tensor: torch.Tensor): def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ - loss func接口定义: - data: DataProto, 由train_step透传 - output_tensor: torch.Tensor, model.forward()的输出Tensor + loss func interface definition: + data: DataProto, passed through from train_step + output_tensor: torch.Tensor, output tensor from model.forward() """ response_mask = data.batch["response_mask"][:, 1:].long() @@ -324,7 +324,7 @@ def do_checkpoint(self, global_step): with Timer("do_checkpoint") as total_timer: ckpt_id = f"checkpoint-{global_step}" - # actor train是直接存在save dir目录下的,其他role是存在save_dir/cluster_name下的 + # actor train is saved directly in save dir directory, other roles are saved in save_dir/cluster_name save_dir = os.path.join(self.pipeline_config.output_dir, self.worker_name, ckpt_id) self.logger.info(f"save checkpoint-{global_step} to {save_dir}") @@ -341,10 +341,10 @@ def do_checkpoint(self, global_step): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def add_request(self, command, data: DataProto): """ - data req meta_info里需要包含: + data req meta_info needs to contain: request_id: str response_callback_fn: callable - generation_config, 按request设置 + generation_config, set by request """ if command == GenerateRequestType.ALIVE_CHECK: if self.thread_server is not None: @@ -474,9 +474,9 @@ def train_step(self, data: DataProto): def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ - loss func接口定义: - data: DataProto, 由train_step透传 - output_tensor: torch.Tensor, model.forward()的输出Tensor + loss func interface definition: + data: DataProto, passed through from train_step + output_tensor: torch.Tensor, output tensor from model.forward() """ response_mask = data.batch["response_mask"][:, 1:] old_values = data.batch["values"] @@ -534,7 +534,7 @@ def do_checkpoint(self, global_step): class RewardWorker(Worker): """ - Reward Model 使用 
AutoModelForSequenceClassification 协议 + Reward Model uses AutoModelForSequenceClassification protocol """ def __init__(self, worker_config: WorkerConfig): diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 71bd3834..ab328f03 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -10,9 +10,9 @@ class ActorWorker(BaseActorWorker): def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ - loss func接口定义: - data: DataProto, 由train_step透传 - output_tensor: torch.Tensor, model.forward()的输出Tensor + loss func interface definition: + data: DataProto, passed through from train_step + output_tensor: torch.Tensor, output tensor from model.forward() """ response_mask = data.batch["response_mask"][:, 1:].long() @@ -122,12 +122,12 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): """ - 可以基于难度和长度的样本权重 + Sample weights can be based on difficulty and length """ batch_size = response_mask.shape[0] sample_weights = torch.ones(batch_size, device=response_mask.device) - # 1. 基于难度的权重 - 例如:难度越高,权重越大 + # 1. Difficulty-based weights - example: higher difficulty, higher weight if self.pipeline_config.difficulty_loss_weight and "difficulty" in data.non_tensor_batch: try: difficulty = data.non_tensor_batch["difficulty"] @@ -139,12 +139,12 @@ def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): difficulty_weights = 0.5 + 1.5 * norm_difficulty sample_weights = sample_weights * difficulty_weights except Exception as e: - self.logger.warning(f"跳过difficulty权重计算:{str(e)}") + self.logger.warning(f"Skip difficulty weight calculation: {str(e)}") - # 2. 基于长度的权重 - 例如:长度越长,权重越小 + # 2. 
Length-based weights - example: longer length, smaller weight response_lengths = response_mask.sum(dim=1).float() if self.pipeline_config.length_loss_weight: - # 同样归一化长度到[0.5, 2.0]范围 + # Similarly normalize length to [0.5, 2.0] range norm_lengths = (response_lengths - response_lengths.min()) / ( response_lengths.max() - response_lengths.min() + 1e-8 ) diff --git a/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py b/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py index 2aadc5cb..c90b9448 100644 --- a/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/code_sandbox_reward_worker.py @@ -79,7 +79,7 @@ def __init__(self, sandbox_url: str): self.logger = get_logger() def check_format(self, prompt_id: str, text: str): - # 检查格式是否满足要求:回答中必须包含""和代码片段 + # Check if format meets requirements: response must contain "" and code blocks has_think_tag = "" in text has_code_block = "```" in text @@ -92,7 +92,7 @@ def check_format(self, prompt_id: str, text: str): return format def extract_code_blocks(self, prompt, text: str, case_type: str = "input"): - """提取代码块""" + """Extract code blocks""" if "<|begin_of_solution|>" in text: text = text.split("<|begin_of_solution|>")[-1].strip() if "" in text: @@ -122,7 +122,7 @@ def extract_code_blocks(self, prompt, text: str, case_type: str = "input"): return codes, langs, "" def format_sandbox_test(self, test_code, code_language, case_type, test_cases) -> Optional[List[Dict]]: - """格式化sandbox测试用例""" + """Format sandbox test cases""" test_cases_final = [] if code_language is None or code_language == "": # TDO detect programming language @@ -249,7 +249,7 @@ def sandbox_test(self, prompt_id: str, test_cases: List[Dict]) -> Tuple[List[Dic return result def sanbox_result_judge(self, test_cases: List[Dict], sandbox_results: List[Dict]) -> int: - """判断测试用例通过数量""" + """Judge the number of test cases passed""" pass_test_number = 0 error_types = [] for i, responses in enumerate(sandbox_results): @@ -304,7 +304,7 @@ def single_code_test( prompt: str = "", flag: int = 0, ): - """单条代码测试""" + """Single code test""" info = { "global_step": global_step, "prompt_id": prompt_id, @@ -320,11 +320,11 @@ def single_code_test( "error_info": [], } - # 判断格式是否满足要求 + # Check if format meets requirements format_validation = self.check_format(prompt_id, code) info["format_validation"] = format_validation start_time = time.time() - # 抽取代码片段 + # Extract code blocks codes, code_langs, error_info = self.extract_code_blocks(prompt, code, case_type) if error_info != "" or len(codes) == 0: info["error_info"] = ["extract_code_blocks error"] @@ -337,13 +337,13 @@ def single_code_test( # TDO detect programming language code_language = "python" - # 格式化sandbox测试用例 + # Format sandbox test cases test_cases, error_info = self.format_sandbox_test(test_code, code_language, case_type, test_cases) if error_info != "" or test_cases == None: info["error_info"] = ["format_sandbox_test error"] return info - # 调用sandbox测试 + # Call sandbox testing succeed_test_cases, responses = self.sandbox_test(prompt_id, test_cases) if not responses or len(succeed_test_cases) == 0: info["error_info"] = ["sandbox error"] @@ -351,7 +351,7 @@ def single_code_test( info["validation"] = 0 return info - # 判断sandbox测试结果 + # Judge sandbox test results pass_test_number, error_types = self.sanbox_result_judge(succeed_test_cases, responses) time_duration = time.time() - start_time @@ -489,8 +489,8 @@ def cal_local_test(prompt_id, response, test_cases, func_name=None, num_process_ class 
CodeSandboxRewardWorker(Worker): """ - (x)Reward Model 使用 AutoModelForSequenceClassification 协议 - 面向code的sandbox 单测的 reward model + (x)Reward Model uses AutoModelForSequenceClassification protocol + Code-oriented sandbox unit test reward model """ def __init__(self, worker_config: RewardConfig): diff --git a/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py index 254849c6..e5e2f92b 100644 --- a/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/crossthinkqa_rule_reward_worker.py @@ -1,4 +1,4 @@ -# 导入必要的库和模块 +# Import necessary libraries and modules from functools import partial from typing import Optional, Union, Iterator import json @@ -21,7 +21,7 @@ from roll.utils.logging import get_logger -logger = get_logger() # 获取日志记录器实例 +logger = get_logger() # Get logger instance def get_response_length_reward(min_len, max_len): @@ -70,19 +70,19 @@ def repetition_penalty_reward(response, **kwargs) -> float: def extract_after_last_think(input_string, end_think=""): """ - 提取输入字符串中最后一个"end_think"标签之后的内容, - 并移除结果字符串开头的所有换行符。 + Extract content after the last "end_think" tag in the input string, + and remove all newlines at the beginning of the result string. Args: - input_string: 原始字符串。 + input_string: Original string. Returns: - 提取并处理后的字符串。如果未找到"end_think"标签,则返回空字符串。 + Extracted and processed string. Returns empty string if "end_think" tag not found. """ last_index = input_string.rfind(end_think) if last_index == -1: - return input_string # 或者根据需要返回 None 或原始字符串 + return input_string # Or return None or original string as needed start_pos = last_index + len(end_think) extracted_part = input_string[start_pos:] @@ -97,9 +97,9 @@ def crossthinkqa_reward_fn(response, ground_truth, reward_type): correct_flag = False # 1. format - # 找到所有的 \\boxed{} 匹配项 + # Find all \\boxed{} matches box_matches = re.findall(r"\\boxed\{([^}]+)\}", response) - # 如果没有找到 \\boxed{} 则返回 None + # If no \\boxed{} found, return None if not box_matches: lower_response = response.lower() last_answer_index = lower_response.rfind("answer is") @@ -107,7 +107,7 @@ def crossthinkqa_reward_fn(response, ground_truth, reward_type): extracted_answer = response else: extracted_answer = response[last_answer_index + 9 :] - # 获取最后一个 \\boxed{} 的内容 + # Get content of the last \\boxed{} else: format_flag = True extracted_answer = box_matches[-1] @@ -145,8 +145,8 @@ def crossthinkqa_reward_fn(response, ground_truth, reward_type): class CrossThinkQARuleRewardWorker(Worker): """ - 一个示例 Reward Worker,用于执行 ifeval 验证并把每个 func 的结果放到 output.tensors 中。 - 在此示例里,ground_truths的str + Example Reward Worker for executing ifeval validation and putting each func result into output.tensors. 
+ In this example, ground_truths string """ def __init__(self, worker_config: WorkerConfig): @@ -179,8 +179,8 @@ def compute_rewards(self, data: DataProto): scores = [] repetition_penalty_rewards = [] response_length_rewards = [] - format_values = [] # 格式正确的value(严格要求有\boxed{}) - correct_values = [] # 答案正确的value(用更宽松的规则提取) + format_values = [] # Format-correct values (strictly require \boxed{}) + correct_values = [] # Answer-correct values (use more lenient extraction rules) for i, (resp_tokens, ground_truth, tag, prompt) in enumerate( zip(data.batch["responses"], ground_truths, tags, prompts) @@ -200,13 +200,13 @@ def compute_rewards(self, data: DataProto): format_value = 1 if format_flag else 0 correct_value = 1 if correct_flag else 0 - # score应该为0或者1,标志模型回复的对错 + # score should be 0 or 1, indicating model response correctness if crossthinkqa_reward > 0: score = 1.0 else: score = 0.0 - # 存到 crossthinkqa_rewards + # Store to crossthinkqa_rewards crossthinkqa_rewards.append(crossthinkqa_reward) scores.append(score) repetition_penalty_rewards.append(repetition_penalty_reward) @@ -248,8 +248,8 @@ def compute_rewards(self, data: DataProto): format_values = torch.tensor(format_values, dtype=torch.float16) correct_values = torch.tensor(correct_values, dtype=torch.float16) - # 5) 将这些张量打包进同一个字典 - # TODO: 不同的reward worker的output是否需要统一output,或者有没有自适应的办法,避免在新增监控量时每个worker都需要修改 + # 5) Pack these tensors into the same dictionary + # TODO: Whether different reward workers' outputs need unified output, or if there's an adaptive approach to avoid modifying every worker when adding new monitoring metrics output_tensors = { "token_level_rewards": token_level_rewards, "response_level_rewards": response_level_rewards, @@ -260,6 +260,6 @@ def compute_rewards(self, data: DataProto): # "correct_values": correct_values } - # 6) 用 DataProto.from_dict(...) 构造返回值 + # 6) Use DataProto.from_dict(...) to construct return value output = DataProto.from_dict(tensors=output_tensors) return output diff --git a/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py index e0268e46..bbbc6db2 100644 --- a/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/general_val_rule_reward_worker.py @@ -1,4 +1,4 @@ -# 导入必要的库和模块 +# Import necessary libraries and modules from functools import partial from typing import Optional, Union, Iterator import json @@ -21,24 +21,24 @@ from roll.utils.logging import get_logger -logger = get_logger() # 获取日志记录器实例 +logger = get_logger() # Get logger instance def extract_after_last_think(input_string, end_think=""): """ - 提取输入字符串中最后一个"end_think"标签之后的内容, - 并移除结果字符串开头的所有换行符。 + Extract content after the last "end_think" tag in the input string, + and remove all newlines at the beginning of the result string. Args: - input_string: 原始字符串。 + input_string: Original string. Returns: - 提取并处理后的字符串。如果未找到"end_think"标签,则返回空字符串。 + Extracted and processed string. Returns empty string if "end_think" tag not found. """ last_index = input_string.rfind(end_think) if last_index == -1: - return input_string # 或者根据需要返回 None 或原始字符串 + return input_string # Or return None or original string as needed start_pos = last_index + len(end_think) extracted_part = input_string[start_pos:] @@ -52,9 +52,9 @@ def single_choice_reward(response, ground_truth): correct_flag = False # 1. 
format - # 找到所有的 \\boxed{} 匹配项 + # Find all \\boxed{} matches box_matches = re.findall(r"\\boxed\{([^}]+)\}", response) - # 如果没有找到 \\boxed{} 则返回 None + # If no \\boxed{} found, return None if not box_matches: lower_response = response.lower() last_answer_index = lower_response.rfind("answer is") @@ -62,7 +62,7 @@ def single_choice_reward(response, ground_truth): extracted_answer = response else: extracted_answer = response[last_answer_index + 9 :] - # 获取最后一个 \\boxed{} 的内容 + # Get content of the last \\boxed{} else: format_flag = True extracted_answer = box_matches[-1] @@ -100,8 +100,8 @@ def single_choice_reward(response, ground_truth): class GeneralValRuleRewardWorker(Worker): """ - 一个示例 Reward Worker,用于执行 ifeval 验证并把每个 func 的结果放到 output.tensors 中。 - 在此示例里,ground_truths的str + Example Reward Worker for executing ifeval validation and putting each func result into output.tensors. + In this example, ground_truths string """ def __init__(self, worker_config: WorkerConfig): @@ -125,8 +125,8 @@ def compute_rewards(self, data: DataProto): tags = data.non_tensor_batch["tag"] scores = [] - format_values = [] # 格式正确的value(严格要求有\boxed{}) - correct_values = [] # 答案正确的value(用更宽松的规则提取) + format_values = [] # Format-correct values (strictly require \boxed{}) + correct_values = [] # Answer-correct values (use more lenient extraction rules) for i, (resp_tokens, ground_truth, tag, prompt) in enumerate( zip(data.batch["responses"], ground_truths, tags, prompts) @@ -141,13 +141,13 @@ def compute_rewards(self, data: DataProto): extracted_answer, reward, format_flag, correct_flag = single_choice_reward(answer_text, ground_truth) format_value = 1 if format_flag else 0 correct_value = 1 if correct_flag else 0 - # score应该为0或者1,标志模型回复的对错 + # score should be 0 or 1, indicating model response correctness if reward > 0: score = 1.0 else: score = 0.0 - # 存到 scores + # Store to scores scores.append(score) format_values.append(format_value) correct_values.append(correct_value) @@ -174,7 +174,7 @@ def compute_rewards(self, data: DataProto): token_level_rewards = torch.zeros_like(data.batch["responses"], dtype=torch.float16) response_level_rewards = torch.zeros_like(scores, dtype=torch.float16) - # 5) 将这些张量打包进同一个字典 + # 5) Pack these tensors into the same dictionary output_tensors = { "scores": scores, # "format_values": format_values, @@ -183,6 +183,6 @@ def compute_rewards(self, data: DataProto): "response_level_rewards": response_level_rewards, } - # 6) 用 DataProto.from_dict(...) 构造返回值 + # 6) Use DataProto.from_dict(...) 
to construct return value output = DataProto.from_dict(tensors=output_tensors) return output diff --git a/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py index f94ad9ea..edd43050 100644 --- a/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/ifeval_rule_reward_worker.py @@ -1,4 +1,4 @@ -# 导入必要的库和模块 +# Import necessary libraries and modules from functools import partial from typing import Optional, Union, Iterator import json @@ -15,7 +15,7 @@ import itertools, collections from collections import defaultdict -# 从已有的 WorkerConfig、Worker、Dispatch 等模块导入 +# Import from existing WorkerConfig, Worker, Dispatch and other modules from roll.configs.worker_config import WorkerConfig from roll.distributed.executor.worker import Worker from roll.distributed.scheduler.decorator import Dispatch, register @@ -31,40 +31,40 @@ from nltk.corpus import wordnet as wn from nltk.wsd import lesk -# 假设 tokenizer 依然来自 default_tokenizer_provider +# Assume tokenizer still comes from default_tokenizer_provider from roll.models.model_providers import default_reward_model_provider, default_tokenizer_provider -# 引入 ifeval 验证函数的字典映射 -# IF_FUNCTIONS_MAP 是题主在上面给出的完整实现中包含的函数映射 +# Import dictionary mapping of ifeval validation functions +# IF_FUNCTIONS_MAP is the function mapping included in the complete implementation given above from typing import Union, Dict, List from roll.utils.logging import get_logger -logger = get_logger() # 获取日志记录器实例 +logger = get_logger() # Get logger instance def first_boxed(text: str) -> str | None: """ - 提取第一个 \boxed{...} 的内容,支持 boxed 内部再嵌套 {}。 + Extract content of the first \boxed{...}, supporting nested {} inside boxed. """ marker = r"\boxed{" start = text.find(marker) if start == -1: - return "" # 没找到 \boxed{ + return "" # No \boxed{ found - i = start + len(marker) # 跳过 '\boxed{' - depth = 1 # 已进入 1 层 { + i = start + len(marker) # Skip '\boxed{' + depth = 1 # Already entered 1 level of { buf = [] - while i < len(text) and depth: # 扫描直到配平 + while i < len(text) and depth: # Scan until balanced ch = text[i] if ch == "{": depth += 1 elif ch == "}": depth -= 1 - if depth == 0: # 恢复到 0 说明 boxed 完成 + if depth == 0: # Return to 0 means boxed is complete break - if depth: # 只在括号未配平时记录字符 + if depth: # Only record characters when brackets are not balanced buf.append(ch) i += 1 @@ -73,8 +73,8 @@ def first_boxed(text: str) -> str | None: class timeout: """ - 与 MathRewardWorker 示例中类似的超时上下文,用于演示, - 如果不需要超时,可直接省略。 + Timeout context similar to MathRewardWorker example, for demonstration, + can be omitted directly if timeout is not needed. 
""" def __init__(self, seconds=1, error_message="Timeout"): @@ -92,97 +92,97 @@ def __exit__(self, type, value, traceback): signal.alarm(0) -# 包含关键字:在你的回答中应包含关键字 {keyword1}、{keyword2} +# Contains keywords: Your answer should include keywords {keyword1}, {keyword2} def verify_keywords(text, keyword_list): """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 将响应文本转换为小写,以便进行不区分大小写的匹配 + # Convert response text to lowercase for case-insensitive matching response_lower = text.lower() - # 检查响应中是否包含所有关键字(每个关键字也转换为小写进行匹配) + # Check if response contains all keywords (each keyword also converted to lowercase for matching) return all(keyword.lower() in response_lower for keyword in keyword_list) -# 关键字出现频率:在你的回答中,单词 {word} 应该出现 {N} 次 +# Keyword frequency: In your answer, word {word} should appear {N} times def verify_keyword_frequency(text, word, N): """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 将文本转换为小写,使搜索不区分大小写 + # Convert text to lowercase for case-insensitive search text = text.lower() word = word.lower() - # 使用正则表达式匹配单词边界,将文本切分为单词列表 + # Use regex to match word boundaries and split text into word list words = re.findall(r"\b\w+\b", text) - # 统计实际出现次数(精确匹配关键字) + # Count actual occurrences (exact match keywords) actual_count = sum(1 for word in words if word == word) - # 检查实际出现次数是否等于期望的 N 次 + # Check if actual occurrence count equals expected N times constraint_met = actual_count == N return constraint_met -# 禁止出现特定单词:回答中不应包含关键字 {forbidden words} +# Forbidden words: Answer should not contain keywords {forbidden words} def validate_forbidden_words(text, forbidden_words): """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 将文本转换为小写,进行不区分大小写的匹配 + # Convert text to lowercase for case-insensitive matching text_lower = text.lower() - # 检查每个禁止单词是否出现在文本中 + # Check if each forbidden word appears in the text found_words = [word for word in forbidden_words if word.lower() in text_lower] - # 如果没有找到禁止单词,返回 True;否则返回 False + # If no forbidden words found, return True; otherwise return False return len(found_words) == 0 -# 字母出现频率:在你的回答中,字母 {letter} 应该恰好出现 {N} 次 +# Letter frequency: In your answer, letter {letter} should appear exactly {N} times def verify_letter_frequency(text: str, letter: str, N: int) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ if len(letter) != 1: - raise ValueError("字母参数必须为单个字符") + raise ValueError("Letter parameter must be a single character") actual_count = text.count(letter) return actual_count == N -# 回答语言约束:你的整个回答应当使用 {language},不允许包含其他语言内容 +# Response language constraint: Your entire answer should use {language}, no other language content allowed def validate_response_language(text, language): """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ from langdetect import detect - # 检测文本的语言 + # Detect text language detected_language = detect(text) - # 检查检测到的语言是否与预期语言相符 + # Check if detected language matches expected language return detected_language == language -# 段落数量:回答中应包含 {N} 个段落,段落之间使用 markdown 分隔符 "* * *" 隔开 +# Paragraph count: Answer should contain {N} paragraphs, separated by markdown separator "* * *" def verify_paragraph_count(text: str, N: int) -> bool: """ Reference implementation from: 
https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ def clean_text(text: str) -> str: - """移除多余空白字符,并规范换行符""" + """Remove excess whitespace and normalize line breaks""" return "\n".join(line.strip() for line in text.splitlines()).strip() - # 清理输入文本 + # Clean input text text = clean_text(text) - # 依据 markdown 分隔符分割文本,每个分隔符会创建 n+1 个段落 + # Split text by markdown separator, each separator creates n+1 paragraphs paragraphs = text.split("* * *") actual_count = len(paragraphs) - # 验证每个分割结果中是否包含非空内容 + # Verify each split result contains non-empty content valid_paragraphs = [p.strip() for p in paragraphs if p.strip()] if len(valid_paragraphs) != actual_count: return False @@ -190,16 +190,16 @@ def clean_text(text: str) -> str: return actual_count == N -# 单词数量约束:回答中的单词数应至少/大约/最多达到 {N} 个 +# Word count constraint: Answer word count should be at least/around/at most {N} def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 清除多余空白字符并拆分文本为单词列表 + # Remove excess whitespace and split text into word list words = text.strip().split() actual_count = len(words) - # 定义 "around" 约束的容错范围(目标单词数的 ±10%,至少 1 个单词) + # Define tolerance range for "around" constraint (±10% of target word count, at least 1 word) tolerance = max(round(N * 0.1), 1) if quantifier == "at least": @@ -212,18 +212,18 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: return False -# 句子数量约束:回答中应包含至少/大约/最多 {N} 个句子 +# Sentence count constraint: Answer should contain at least/around/at most {N} sentences def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 使用正则表达式根据句号或问号后的空格拆分文本为句子列表 + # Use regex to split text into sentence list based on periods or question marks followed by spaces sentences = re.split(r"(?= N elif quantifier == "around": @@ -234,61 +234,61 @@ def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: return False -# 段落数量及指定段落首词约束:回答中应包含 {N} 个段落,段落之间仅以两个换行符分隔,第 {i} 个段落必须以 {first word} 开头 +# Paragraph count and specific paragraph first word constraint: Answer should contain {N} paragraphs separated only by two newlines, paragraph {i} must start with {first word} def validate_paragraphs(text, N, first_word, i): """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 根据两个换行符分割文本为段落 + # Split text into paragraphs by two newlines paragraphs = text.split("\n\n") - # 检查段落总数是否符合要求 + # Check if total paragraph count meets requirements if len(paragraphs) != N: return False - # 检查第 i 个段落的开头是否为指定单词 + # Check if paragraph i starts with specified word if paragraphs[i - 1].strip().startswith(first_word): return True return False -# 附言验证:请在回答末尾明确添加以 {postscript marker} 开头的附言 +# Postscript validation: Please clearly add a postscript starting with {postscript marker} at the end of your answer def verify_postscript(text, postscript_marker): """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 检查文本中是否包含附言标记 + # Check if text contains postscript marker if postscript_marker in text: - # 获取标记的索引位置 + # Get marker index position marker_index = text.find(postscript_marker) - # 检查标记附近是否还有其它内容 + # Check if there's other content near the marker remaining_text = 
text[marker_index:].strip() - # 验证附言不只是标记本身而已 + # Verify postscript is not just the marker itself return len(remaining_text) > len(postscript_marker) return False -# 占位符验证:回答中应至少包含 {N} 个用方括号表示的占位符,例如 [address] +# Placeholder validation: Answer should contain at least {N} placeholders in square brackets, e.g. [address] def validate_placeholders(text: str, N: int) -> tuple[bool, List[str]]: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 使用正则表达式查找所有位于方括号内的内容 + # Use regex to find all content within square brackets pattern = r"\[(.*?)\]" placeholders = re.findall(pattern, text) - # 检查是否至少找到了 N 个占位符 + # Check if at least N placeholders were found has_enough = len(placeholders) >= N return has_enough -# 项目符号验证:回答必须包含恰好 {N} 个项目符号点。请使用 markdown 格式的项目点,例如:* 这是一个点。 +# Bullet point validation: Answer must contain exactly {N} bullet points. Use markdown format bullet points, e.g.: * This is a point. def verify_bullet_points(text: str, N: int) -> tuple[bool, str]: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ - # 按行拆分文本,并统计以 * 或 - 开头的行 + # Split text by lines and count lines starting with * or - lines = text.split("\n") bullet_points = [line.strip() for line in lines if line.strip().startswith(("*", "-"))] actual_count = len(bullet_points) @@ -299,7 +299,7 @@ def verify_bullet_points(text: str, N: int) -> tuple[bool, str]: return False -# 标题验证:回答中必须包含一个标题,用双尖括号包裹,例如 <> +# Title validation: Answer must contain a title wrapped in double angle brackets, e.g. <> def validate_title(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -313,7 +313,7 @@ def validate_title(text: str) -> bool: return False -# 选择题验证:回答内容必须为以下选项之一:{options} +# Multiple choice validation: Answer content must be one of the following options: {options} def validate_choice(text: str, options: list) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -324,7 +324,7 @@ def validate_choice(text: str, options: list) -> bool: return False -# 高亮区域数量验证:回答中必须至少高亮 {N} 个区域,使用 markdown 格式,比如 *highlighted section* +# Highlighted section count validation: Answer must highlight at least {N} sections using markdown format, e.g. 
*highlighted section* def validate_highlighted_sections(text: str, N: int) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -338,13 +338,13 @@ def validate_highlighted_sections(text: str, N: int) -> bool: return False -# 多区块验证:回答中必须包含 {N} 个区块,每个区块的开始都应以 {section splitter} 开头 +# Multi-section validation: Answer must contain {N} sections, each section should start with {section splitter} def validate_sections(text: str, N: int, section_splitter: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py """ sections = text.split(section_splitter) - # 第一个区块可能不以分割符开头,因此需要做调整 + # First section may not start with splitter, so adjustment needed if sections[0] == "": sections.pop(0) if len(sections) == N: @@ -353,7 +353,7 @@ def validate_sections(text: str, N: int, section_splitter: str) -> bool: return False -# JSON 格式验证:整个输出必须使用 JSON 格式包裹 +# JSON format validation: Entire output must be wrapped in JSON format def validate_json_format(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -365,7 +365,7 @@ def validate_json_format(text: str) -> bool: return True -# 重复提示验证:首先重复用户的请求内容不做更改,然后再给出你的回答(重复内容不应包含其他额外信息) +# Repeat prompt validation: First repeat user's request content unchanged, then give your answer (repeated content should not contain other additional information) def validate_repeat_prompt(text: str, original_prompt: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -376,7 +376,7 @@ def validate_repeat_prompt(text: str, original_prompt: str) -> bool: return False -# 两种回答验证:提供两种不同的回答,两个回答之间仅用六个星号 "******" 分隔开 +# Two responses validation: Provide two different answers separated only by six asterisks "******" def validate_two_responses(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -390,7 +390,7 @@ def validate_two_responses(text: str) -> bool: return False -# 全部大写:整个回答必须全部使用英文大写字母 +# All uppercase: Entire answer must use English uppercase letters def validate_uppercase(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -401,7 +401,7 @@ def validate_uppercase(text: str) -> bool: return False -# 全部小写:整个回答必须全部使用英文小写字母,不允许有大写字母 +# All lowercase: Entire answer must use English lowercase letters, no uppercase allowed def validate_lowercase(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -412,7 +412,7 @@ def validate_lowercase(text: str) -> bool: return False -# 全大写单词出现频率验证:在回答中,全大写单词的出现次数应满足至少/大约/最多 {N} 次 +# All-caps word frequency validation: In the answer, all-caps words should appear at least/around/at most {N} times def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -428,7 +428,7 @@ def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool return False -# 结束语验证:回答最后必须以确切的短语 {end phrase} 结束,且该短语后面不允许有其他内容 +# End phrase validation: Answer must end with exact phrase {end phrase}, with no other content after it def validate_end(text: str, 
end_phrase: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -439,7 +439,7 @@ def validate_end(text: str, end_phrase: str) -> bool: return False -# 引号包装验证:整个回答必须用双引号包裹起来 +# Quotation wrapping validation: Entire answer must be wrapped in double quotes def validate_quotation(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -450,7 +450,7 @@ def validate_quotation(text: str) -> bool: return False -# 禁用逗号:整个回答中不允许出现任何逗号 +# Comma ban: No commas allowed in entire answer def validate_no_commas(text: str) -> bool: """ Reference implementation from: https://github.com/allenai/open-instruct/blob/main/open_instruct/if_functions.py @@ -463,23 +463,23 @@ def validate_no_commas(text: str) -> bool: def call_ifeval_function(func, text: str, constraint_dict: dict): """ - 1) 获取func的函数签名 - 2) 只保留与签名匹配且非None的参数 - 3) 调用func(text, **filtered_args) + 1) Get function signature of func + 2) Only keep parameters that match signature and are not None + 3) Call func(text, **filtered_args) """ - # 1) 获取函数签名 + # 1) Get function signature sig = inspect.signature(func) - valid_params = set(sig.parameters.keys()) # 该函数形参名的集合 + valid_params = set(sig.parameters.keys()) # Set of function parameter names - # 2) 过滤掉 constraint_dict 中的无关字段和 None 值 - # (如果一个函数的参数刚好是 None 值也是合法,就保留;否则你可以额外判断) + # 2) Filter out irrelevant fields and None values in constraint_dict + # (If a function parameter is None value but still valid, keep it; otherwise you can add extra judgment) filtered_args = {} for k, v in constraint_dict.items(): - if k in valid_params: # 形参里确实有这个字段 - # 如果你想彻底丢弃 None,也可以加上: if v is not None: + if k in valid_params: # Parameter list actually has this field + # If you want to completely discard None, you can add: if v is not None: filtered_args[k] = v - # 3) 调用函数 + # 3) Call function return func(text, **filtered_args) @@ -487,35 +487,35 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): """ Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py """ - # 如果 max_penalty 是正的,这里直接抛出错误,说明要用负值来做惩罚 + # If max_penalty is positive, throw error directly, indicating negative values should be used for penalty if max_penalty > 0: raise ValueError(f"max_penalty {max_penalty} should not be positive") - # 内部函数 zipngram,用于切分文本为 ngram + # Internal function zipngram for splitting text into ngrams def zipngram(text: str, ngram_size: int): words = text.lower().split() return zip(*[words[i:] for i in range(ngram_size)]) - # repetition_penalty_reward 函数用于计算在给定 response 中,n-gram 的重复程度 + # repetition_penalty_reward function calculates n-gram repetition degree in given response def repetition_penalty_reward(response, **kwargs) -> float: """ ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py """ - # 如果回复为空或不足 ngram 大小,则直接返回 0 + # If response is empty or insufficient ngram size, return 0 directly if response == "" or len(response.split()) < ngram_size: return 0.0 - # 遍历所有 ngram,统计 unique ngram 和 total ngram 的数量 + # Iterate all ngrams, count unique ngram and total ngram quantities ngrams = set() total = 0 for ng in zipngram(response, ngram_size): ngrams.add(ng) total += 1 - # scaling = 1 - (不重复的 ngram / 总的 ngram 数量) - # 不重复的越少(重复越多)scaling 越大 + # scaling = 1 - (non-repeated ngrams / total ngram count) + # The fewer 
non-repeated (more repeated), the larger scaling scaling = 1 - len(ngrams) / total - # reward 是 scaling 乘以 max_penalty + # reward is scaling multiplied by max_penalty reward = scaling * max_penalty return reward @@ -524,30 +524,30 @@ def repetition_penalty_reward(response, **kwargs) -> float: def extract_after_last_think(input_string, end_think=""): """ - 提取输入字符串中最后一个 '' 标签之后的内容, - 并移除结果字符串开头的所有换行符。 + Extract content after the last '' tag in the input string, + and remove all newlines at the beginning of the result string. Args: - input_string: 原始字符串。 + input_string: Original string. Returns: - 提取并处理后的字符串。如果未找到 '' 标签,则返回空字符串。 + Extracted and processed string. Returns empty string if '' tag not found. """ - # 查找最后一个 end_think 的起始位置 + # Find starting position of last end_think last_index = input_string.rfind(end_think) - # 如果没有找到 end_think + # If end_think not found if last_index == -1: # return "" - return input_string # 或者根据需要返回 None 或原始字符串 + return input_string # Or return None or original string as needed - # 计算 end_think 结束后的位置 + # Calculate position after end_think ends start_pos = last_index + len(end_think) - # 提取 end_think 之后的部分 + # Extract part after end_think extracted_part = input_string[start_pos:] - # 移除开头的所有换行符 '\n' + # Remove all leading newlines '\n' cleaned_part = extracted_part.lstrip("\n") return cleaned_part @@ -555,8 +555,8 @@ def extract_after_last_think(input_string, end_think=""): class GeneralRuleRewardWorker(Worker): """ - 一个示例 Reward Worker,用于执行 ifeval 验证并把每个 func 的结果放到 output.tensors 中。 - 在此示例里,ground_truths的str + Example Reward Worker for executing ifeval validation and putting each func result into output.tensors. + In this example, ground_truths string """ def __init__(self, worker_config: WorkerConfig): @@ -577,20 +577,20 @@ def initialize(self, pipeline_config): @register(dispatch_mode=Dispatch.DP_MP_COMPUTE) def compute_rewards(self, data: DataProto): """ - 仅调用 data.non_tensor_batch['ground_truth'] 中的 “func_name”, - 并将其结果作为单一的 response-level 奖励返回。 + Only call "func_name" in data.non_tensor_batch['ground_truth'], + and return its result as single response-level reward. 
""" - # 1) 解码回复文本 + # 1) Decode response text response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=False) batch_size = len(response_text_list) - # 2) 读取 ground_truth(其中是一串 JSON,包含 func_name 等参数) + # 2) Read ground_truth (which is a JSON string containing func_name and other parameters) prompts = data.non_tensor_batch["prompt"] ground_truths = data.non_tensor_batch["ground_truth"] tags = data.non_tensor_batch["tag"] - # 3) 准备一个列表存放验证结果 + # 3) Prepare a list to store validation results results = [0.0] * batch_size repetition_penalty_rewards = [] response_length_rewards = [] @@ -598,35 +598,35 @@ def compute_rewards(self, data: DataProto): for i, (resp_tokens, ground_truth, tag, prompt) in enumerate( zip(data.batch["responses"], ground_truths, tags, prompts) ): - # 解码当前条目 + # Decode current entry resp_text = self.tokenizer.decode(resp_tokens, skip_special_tokens=False) resp_text1 = resp_text.replace("<|endoftext|>", "").replace("", "").replace("<|im_end|>", "") resp_text = extract_after_last_think(resp_text1) # logger.info(f"extract_after_last_think(resp_text): {resp_text}") if tag == "ifeval": - # 解析 ground_truth (JSON) 得到约束信息 + # Parse ground_truth (JSON) to get constraint information if isinstance(ground_truth, str): constraint_dict = json.loads(ground_truth) else: - constraint_dict = ground_truth # 如果已经是 dict,就直接用 + constraint_dict = ground_truth # If already dict, use directly - # 从约束中取出 func_name + # Extract func_name from constraints func_name = constraint_dict.get("func_name", None) if not func_name or func_name not in IF_FUNCTIONS_MAP: self.logger.warning("constraint missing func_name") - # 如果无 func_name 或没找到对应函数 - # 那么这里我们将结果记为 0.0(也可做别的处理) + # If no func_name or corresponding function not found + # Then we record result as 0.0 (can do other processing) results[i] = 0.0 continue - # 移除 func_name,其它参数传给函数 + # Remove func_name, pass other parameters to function constraint_dict.pop("func_name") func = IF_FUNCTIONS_MAP[func_name] # print(f"Running function {func_name} with Response text: {resp_text}") # print(f"Response text: {resp_text}") - # 调用函数进行验证 + # Call function for validation try: result = call_ifeval_function(func, resp_text, constraint_dict) except Exception as e: @@ -635,7 +635,7 @@ def compute_rewards(self, data: DataProto): else: self.logger.warning(f"Unknown tag: {tag}") - # 将结果转为 float: bool -> (1.0/0.0), 数值 -> float(...), 其他结构 -> bool(...) + # Convert result to float: bool -> (1.0/0.0), numeric -> float(...), other structures -> bool(...) if isinstance(result, bool): val = 1.0 if result else 0.0 elif isinstance(result, (int, float)): @@ -643,27 +643,27 @@ def compute_rewards(self, data: DataProto): else: val = 1.0 if result else 0.0 - # 存到 results + # Store to results results[i] = val repetition_penalty_rewards.append(self.repetition_penalty_reward_fn(resp_text1)) - # 4) 准备输出张量: - # - token_level_rewards:形状与 responses 相同、全 0 - # - response_level_rewards:即 results - # - scores:可与 response_level_rewards 相同(用于统计/日志) + # 4) Prepare output tensors: + # - token_level_rewards: same shape as responses, all 0 + # - response_level_rewards: i.e. 
results + # - scores: can be same as response_level_rewards (for statistics/logging) token_level_rewards = torch.zeros_like(data.batch["responses"], dtype=torch.float16) scores = torch.tensor(results, dtype=torch.float16) repetition_penalty_rewards = torch.tensor(repetition_penalty_rewards, dtype=torch.float16) response_level_rewards = scores + repetition_penalty_rewards - # 5) 将这些张量打包进同一个字典 + # 5) Pack these tensors into the same dictionary output_tensors = { "token_level_rewards": token_level_rewards, "response_level_rewards": response_level_rewards, "scores": scores } - # 6) 用 DataProto.from_dict(...) 构造返回值 + # 6) Use DataProto.from_dict(...) to construct return value output = DataProto.from_dict(tensors=output_tensors) return output diff --git a/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py b/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py index e31ec774..fa8eaf99 100644 --- a/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py @@ -34,7 +34,7 @@ def __init__(self, worker_config: WorkerConfig): self.tokenizer = None self.strategy: Optional[Union[InferenceStrategy, TrainStrategy]] = None - # LLM judge相关配置 + # LLM judge related configuration self.judge_prompt = self.worker_config.judge_prompt if hasattr(self.worker_config, "judge_prompt") else None self.judge_prompt = prompt_maps[self.judge_prompt] self.judge_model_type = ( diff --git a/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py index ca299f08..7db81236 100644 --- a/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py @@ -48,8 +48,8 @@ def _hf_verify_math_sample(response, answer, output_queue): def hf_verify_math_sample(answer_a, answer_b, timeout_sec=5.0): """ - 在多进程中调用 hf math verify, - 以在超时时间内完不成时返回 False. + Call hf math verify in multiprocessing, + return False when timeout occurs. 
""" output_queue = multiprocessing.Queue() @@ -58,7 +58,7 @@ def hf_verify_math_sample(answer_a, answer_b, timeout_sec=5.0): p.join(timeout_sec) if p.is_alive(): - # 超时 -> 杀掉子进程, 返回 False + # Timeout -> kill subprocess, return False p.terminate() p.join() return False, "", "" @@ -118,8 +118,8 @@ def format_reward_fn(text: str, pattern: Optional[str] = r"^.*?.* class MathRuleRewardWorker(Worker): """ - (x)Reward Model 使用 AutoModelForSequenceClassification 协议 - 面向math的rule reward model + (x)Reward Model uses AutoModelForSequenceClassification protocol + Math-oriented rule reward model """ def __init__(self, worker_config: WorkerConfig): diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index 17cffec8..0da46e81 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -282,7 +282,7 @@ def set_max_steps(self, max_steps: int): self.critic.training_args.per_device_train_batch_size * self.critic.training_args.gradient_accumulation_steps ) - # 没有除dp_size,需要在分布式环境初始化后再除 + # Not divided by dp_size, need to divide after distributed environment initialization self.actor_train.training_args.max_steps = max_steps * ( self.rollout_batch_size * self.actor_infer.generating_args.num_return_sequences diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index 6011f014..eae67ff7 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -37,7 +37,7 @@ def preprocess_dataset(dataset, prompt_len, encode_function, num_proc): - # 处理数据 + # Process data print(f"Begin : {dataset}") dataset = dataset.map( encode_function, @@ -46,7 +46,7 @@ def preprocess_dataset(dataset, prompt_len, encode_function, num_proc): desc="Encoding dataset", load_from_cache_file=False, ) - # 过滤cutoff + # Filter cutoff dataset = dataset.filter( lambda data_i: 5 < len(data_i["input_ids"]) <= prompt_len, num_proc=num_proc, @@ -83,7 +83,7 @@ def update_dataset_domain(tag_2_domain: Dict[str, set[str]], row): def query_filter_fn(data_list: List[DataProto], config: RLVRConfig) -> bool: """ - 各domain的过滤规则可以自定义 + Custom filtering rules for each domain """ response_level_rewards = [data.batch["response_level_rewards"] for data in data_list] if len(response_level_rewards) == 1: @@ -127,7 +127,7 @@ def __init__(self, pipeline_config: RLVRConfig): val_dataset_paths = self.pipeline_config.validation.data_args.file_name self.val_dataset = datasets.load_dataset("json", data_files=val_dataset_paths)["train"] - # 加上format,然后转ids的func + # Add format, then convert to ids function template_name = ( self.pipeline_config.global_template if self.pipeline_config.global_template @@ -311,9 +311,9 @@ def __init__(self, pipeline_config: RLVRConfig): @torch.no_grad() def run(self): - # 计算tokens per second 系统吞吐 + # Calculate tokens per second system throughput - # 创建一个专门管理监控指标的类 + # Create a specialized class for managing monitoring metrics metrics_mgr = MetricsManager() tps_timer = _Timer(window_size=5) @@ -330,7 +330,7 @@ def run(self): metrics_mgr.clear_metrics() with tps_timer, Timer(name="step_total", logger=None) as step_total_timer: - # 先model update,resume时不需要保存infer cluster的状态 + # First model update, no need to save infer cluster state during resume if self.pipeline_config.adv_estimator == "gae": self.critic.offload_states(blocking=True) self.actor_train.offload_states(blocking=True) @@ -349,7 +349,7 @@ def run(self): batch: DataProto = DataProto() batch.meta_info = {"global_step": global_step} - # 要按domain group by生成对应的batch + # 
Generate corresponding batches grouped by domain with actor_infer_timer, actor_infer_response_timer, Timer( name="step_generate", logger=None ) as step_generate_timer: @@ -409,18 +409,18 @@ def run(self): metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) metrics_mgr.add_metric("time/old_log_probs", cal_old_logpb_timer.last) - # 要按domain group by处理reward + # Process rewards grouped by domain batch.batch["prompt_id"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device) batch_grouped: Dict[str, DataProto] = batch.group_by("domain") batch_list = [] for domain, domain_batch in batch_grouped.items(): - # 1. 处理mask相关策略, 获取sample level mask + # 1. Process mask related strategies, get sample level mask with Timer(name="get_sample_level_mask", logger=None) as get_sample_level_mask_timer: domain_batch, mask_metrics = get_sample_level_mask(domain_batch, self.pipeline_config) metrics_mgr.add_metrics(mask_metrics) metrics_mgr.add_metric("time/get_sample_level_mask", get_sample_level_mask_timer.last) - # 2. 处理reward相关策略 + # 2. Process reward related strategies with Timer(name="reward_postprocess", logger=None) as reward_postprocess_timer: domain_batch, response_level_metrics = reward_postprocess( domain_batch, self.pipeline_config, self.running @@ -428,7 +428,7 @@ def run(self): metrics_mgr.add_metrics(response_level_metrics) metrics_mgr.add_metric("time/reward_postprocess", reward_postprocess_timer.last) - # 3. 计算token level rewards + # 3. Calculate token level rewards with Timer(name="get_token_reward", logger=None) as get_token_reward_timer: domain_batch, token_level_metrics = compute_token_reward( domain_batch, self.pipeline_config, self.kl_ctrl @@ -436,7 +436,7 @@ def run(self): metrics_mgr.add_metrics(token_level_metrics) metrics_mgr.add_metric("time/get_token_reward", get_token_reward_timer.last) - # 4. 计算advantage + # 4. 
Calculate advantage final_response_mask = domain_batch.batch["final_response_mask"].clone() with Timer(name="compute_advantage", logger=None) as compute_advantage_timer: domain_batch = compute_advantage( diff --git a/roll/third_party/deepspeed/offload_states_patch.py b/roll/third_party/deepspeed/offload_states_patch.py index 62f32092..de38aa79 100644 --- a/roll/third_party/deepspeed/offload_states_patch.py +++ b/roll/third_party/deepspeed/offload_states_patch.py @@ -180,7 +180,7 @@ def stage_1_and_2_offload_states(self: DeepSpeedZeroOptimizer, # LP param if needs_offload(OffloadStateTypeEnum.lp_params, include, self.offloaded_states) and not self.bit16_groups_flat[ 0].is_cpu: - # NOTE: 这里只支持offload optimizer 里的参数部分 + # NOTE: Only supports offloading parameters in optimizer here if pin_memory: if not hasattr(self, "lp_params_pin_buffers"): self.lp_params_pin_buffers = [ @@ -200,7 +200,7 @@ def stage_1_and_2_offload_states(self: DeepSpeedZeroOptimizer, self.offloaded_states.add(OffloadStateTypeEnum.lp_params) # LP grad - # NOTE: 这里好像没有 grad 缓存 + # NOTE: There seems to be no grad cache here if needs_offload(OffloadStateTypeEnum.lp_grads, include, self.offloaded_states): pass @@ -231,7 +231,7 @@ def stage_1_and_2_offload_states(self: DeepSpeedZeroOptimizer, offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) self.offloaded_states.add(OffloadStateTypeEnum.optim_states) - # NOTE: 清理额外引用,hp_mapping里包含了一份对全部flat tensor的引用 + # NOTE: Clean up extra references, hp_mapping contains a reference to all flat tensors for group in self.bit16_groups: for param in group: param._hp_mapping = None @@ -286,7 +286,7 @@ def stage_1_and_2_reload_states(self: DeepSpeedZeroOptimizer, include=None, non_ reload_adam_states(self.optimizer, device, non_blocking=non_blocking) self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) - # NOTE: 恢复link + # NOTE: Restore link for group in self.bit16_groups: for param in group: param._hp_mapping = None @@ -329,9 +329,9 @@ def stage_3_offload_states(self: DeepSpeedZeroOptimizer_Stage3, if not hasattr(self, "lp_param_contiguous_pin_buffer"): self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( torch.empty_like(self.lp_param_buffer, device=device)) - # NOTE: lp_param_buffer保存了由optimizer里取到的参数顺序 - # offload的时候先将 lp_param_buffer.cpu() - # 然后将tensor.data cp给model 的tensor.data,这一步也会有顺序不一致问题 + # NOTE: lp_param_buffer stores parameter order retrieved from optimizer + # When offloading, first move lp_param_buffer.cpu() + # Then copy tensor.data to model's tensor.data, this step may also have order inconsistency issues self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) cpu_buffer = self.lp_param_contiguous_pin_buffer else: @@ -359,7 +359,7 @@ def stage_3_offload_states(self: DeepSpeedZeroOptimizer_Stage3, else: self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) self.averaged_gradients = {} - # NOTE: self.__param_id_to_grad_partition里存了一份对grad_partitions_flat_buffer的引用,patch修改需要使用名称修饰 + # NOTE: self.__param_id_to_grad_partition stores a reference to grad_partitions_flat_buffer, patch modifications need to use name mangling setattr(self, "_DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition", {}) self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) @@ -398,8 +398,8 @@ def stage_3_reload_states(self: DeepSpeedZeroOptimizer_Stage3, include=None, non self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking) 
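# --- Illustrative aside (not part of the patch) ---------------------------------
# The comments above describe ZeRO-3's trick of staging the flat fp16 parameter
# buffer through a pinned CPU tensor so the device<->host copies can be issued with
# non_blocking=True. A minimal standalone sketch of that pattern, assuming a CUDA
# device; the names here are illustrative and this is plain torch, not the DeepSpeed API:
import torch

def offload_flat_buffer(flat_gpu: torch.Tensor) -> torch.Tensor:
    # Stage through pinned host memory so the D2H copy can overlap with other work.
    cpu_staging = torch.empty_like(flat_gpu, device="cpu").pin_memory()
    cpu_staging.copy_(flat_gpu, non_blocking=True)
    torch.cuda.synchronize()  # the copy must finish before the GPU storage is dropped
    flat_gpu.data = torch.empty(0, dtype=flat_gpu.dtype, device=flat_gpu.device)
    return cpu_staging

def reload_flat_buffer(flat_gpu: torch.Tensor, cpu_staging: torch.Tensor, device: str = "cuda") -> None:
    # Copy the flat buffer back; callers must then re-narrow the per-parameter views
    # in the original flattening order, which is exactly what the patch is careful about.
    flat_gpu.data = cpu_staging.to(device, non_blocking=True)
    torch.cuda.synchronize()
# ---------------------------------------------------------------------------------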
self._set_fp16_partitioned_groups_flat() - # NOTE: 这里遍历的是self.module.parameters(), 而lp_param_buffer里的是fp16 group里取到的,这里参数的顺序不一致 - # 这里[p.ds_tensor for p in self.module.parameters()]需要按self.fp16_groups的顺序reorder一下 + # NOTE: Here we iterate over self.module.parameters(), while lp_param_buffer contains parameters from fp16 groups, the parameter order is inconsistent here + # Here [p.ds_tensor for p in self.module.parameters()] needs to be reordered according to self.fp16_groups order parameter_partitions: List[Tensor] = [param.ds_tensor for sub_group in self.fp16_groups for param in sub_group] for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(parameter_partitions): tensor.data = self.lp_param_buffer.data.narrow(0, offset, tensor_numel) @@ -450,7 +450,7 @@ def parameter_offload_offload_states(self: DeepSpeedZeRoOffload, self.offloaded_states = getattr(self, "offloaded_states", set()) if needs_offload(OffloadStateTypeEnum.lp_params, include, self.offloaded_states): - # NOTE: 这里不会执行了non_trainable_params都在engine里处理了 + # NOTE: This won't execute since non_trainable_params are all handled in the engine if not hasattr(self, "trainable_params"): self.trainable_params = [param.ds_tensor for param in self.module.parameters() if param.requires_grad] if len(self.trainable_params) == 0: diff --git a/roll/third_party/megatron/offload_states_patch.py b/roll/third_party/megatron/offload_states_patch.py index e204be22..0ac50687 100644 --- a/roll/third_party/megatron/offload_states_patch.py +++ b/roll/third_party/megatron/offload_states_patch.py @@ -1,11 +1,11 @@ """ -megatron offload states的实现思路: +Megatron offload states implementation approach: offload -释放megatron.core.distributed.distributed_data_parallel.DistributedDataParallel中的buffer -offload optimizer中的main_weights, main_weights.to('cpu'),使用flat tensor -offload optimizer states, to('cpu') -offload model weights, to('cpu'), 使用flat tensor;释放shard_float16_groups和shard_fp32_groups +Release buffers in megatron.core.distributed.distributed_data_parallel.DistributedDataParallel +Offload main_weights in optimizer, main_weights.to('cpu'), using flat tensor +Offload optimizer states, to('cpu') +Offload model weights, to('cpu'), using flat tensor; release shard_float16_groups and shard_fp32_groups reload @@ -221,10 +221,10 @@ def move_ddp_model_params_tensor_to_device(optimizer: DistributedOptimizer, # Clone model -> main. 
shard_model_param = model_param.detach().view(-1)[param_range.start: param_range.end] - # shard_float16_groups 不属于optimizer state的key,可以直接替换param - # 这种方式: optimizer.shard_float16_groups[ + # shard_float16_groups is not an optimizer state key, can directly replace param + # This approach: optimizer.shard_float16_groups[ # group_index][len(shard_float16_params_this_group)].data = shard_model_param - # 不能实现显存释放,定位到是model_param.detach()的影响,下面的fp32能正常释放 + # Cannot achieve memory release, identified as the effect of model_param.detach(), fp32 below can be released normally optimizer.shard_float16_groups[group_index][ len(shard_float16_params_this_group)] = shard_model_param shard_float16_params_this_group.append(shard_model_param) @@ -252,7 +252,7 @@ def move_grad_data_to_device(optimizer, # else: # buffer.grad_data.data = buffer.grad_data.data.to(device, non_blocking=non_blocking) - # 释放grad, 节省cpu memory + # Release grad, save CPU memory if device == torch.device('cpu'): buffer.grad_data.data = torch.tensor(1, dtype=buffer.grad_data.data.dtype, device=device, pin_memory=pin_memory) for param in buffer.params[::-1]: @@ -438,7 +438,7 @@ def offload_megatron_no_grad_module(model_chunks: List[Union[DistributedDataPara non_blocking: bool = False ): """ - 需要offload一下 grad=False的参数 + Need to offload parameters with grad=False """ device = torch.device('cpu') diff --git a/roll/third_party/sglang/v043post4_patch/async_engine.py b/roll/third_party/sglang/v043post4_patch/async_engine.py index 096b069e..e28de453 100644 --- a/roll/third_party/sglang/v043post4_patch/async_engine.py +++ b/roll/third_party/sglang/v043post4_patch/async_engine.py @@ -12,7 +12,7 @@ class SglangInputType(enum.Enum): ABORT = enum.auto() def list_endswith(lst, suffix): - # 检查 lst 是否以 suffix 结尾 + # Check if lst ends with suffix return lst[-len(suffix):] == suffix if len(suffix) <= len(lst) else False def trim_overlap_tokens(existing_tokens, new_chunk_tokens): @@ -28,7 +28,7 @@ def trim_overlap_tokens(existing_tokens, new_chunk_tokens): return new_chunk_tokens[max_overlap:] -# 用于存放所有abort_rid_set +# Used to store all abort_rid_set abort_rid_set = set() abort_lock = asyncio.Lock() @@ -38,7 +38,7 @@ async def producer(thread_queue, asyncio_queue): while True: if not thread_queue.empty(): data = thread_queue.get() - # 收到结束标记 + # Received end marker if data is None: logger.info("[sglang async engine] receive stop signal, stoping") break @@ -57,12 +57,12 @@ async def consumer(asyncio_queue, consumer_id, llm, request_complete_callback): from roll.distributed.scheduler.protocol import DataProto def process_sglang_output(token_ids, meta_info): - # 线上正式使用 + # Online production use output_data = DataProto(meta_info=meta_info) output_data.meta_info["output_token_ids"] = token_ids request_complete_callback(data=output_data) - # 本地调试使用 + # Local debugging use # request_complete_callback(meta_info['request_id'], token_ids) logger.debug(f"worker_id:{consumer_id} request_id: {meta_info['request_id']} finish!") diff --git a/roll/third_party/sglang/v046post4_patch/async_engine.py b/roll/third_party/sglang/v046post4_patch/async_engine.py index 096b069e..e28de453 100644 --- a/roll/third_party/sglang/v046post4_patch/async_engine.py +++ b/roll/third_party/sglang/v046post4_patch/async_engine.py @@ -12,7 +12,7 @@ class SglangInputType(enum.Enum): ABORT = enum.auto() def list_endswith(lst, suffix): - # 检查 lst 是否以 suffix 结尾 + # Check if lst ends with suffix return lst[-len(suffix):] == suffix if len(suffix) <= len(lst) else False def 
trim_overlap_tokens(existing_tokens, new_chunk_tokens): @@ -28,7 +28,7 @@ def trim_overlap_tokens(existing_tokens, new_chunk_tokens): return new_chunk_tokens[max_overlap:] -# 用于存放所有abort_rid_set +# Used to store all abort_rid_set abort_rid_set = set() abort_lock = asyncio.Lock() @@ -38,7 +38,7 @@ async def producer(thread_queue, asyncio_queue): while True: if not thread_queue.empty(): data = thread_queue.get() - # 收到结束标记 + # Received end marker if data is None: logger.info("[sglang async engine] receive stop signal, stoping") break @@ -57,12 +57,12 @@ async def consumer(asyncio_queue, consumer_id, llm, request_complete_callback): from roll.distributed.scheduler.protocol import DataProto def process_sglang_output(token_ids, meta_info): - # 线上正式使用 + # Online production use output_data = DataProto(meta_info=meta_info) output_data.meta_info["output_token_ids"] = token_ids request_complete_callback(data=output_data) - # 本地调试使用 + # Local debugging use # request_complete_callback(meta_info['request_id'], token_ids) logger.debug(f"worker_id:{consumer_id} request_id: {meta_info['request_id']} finish!") diff --git a/roll/third_party/vllm/vllm_0_7_3/llm.py b/roll/third_party/vllm/vllm_0_7_3/llm.py index 1c99ab70..0b4c349e 100644 --- a/roll/third_party/vllm/vllm_0_7_3/llm.py +++ b/roll/third_party/vllm/vllm_0_7_3/llm.py @@ -147,7 +147,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def clear_unfinished_requests(self): self._run_engine(use_tqdm=True) - # 参数同步接口 + # Parameter synchronization interface def setup_collective_group(self, *args, **kwargs): self.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs) diff --git a/roll/third_party/vllm/vllm_0_7_3/llm_engine.py b/roll/third_party/vllm/vllm_0_7_3/llm_engine.py index ded9ec38..63fa5502 100644 --- a/roll/third_party/vllm/vllm_0_7_3/llm_engine.py +++ b/roll/third_party/vllm/vllm_0_7_3/llm_engine.py @@ -61,7 +61,7 @@ def update_worker_cls_config(cls, vllm_config: VllmConfig) -> None: parallel_config.worker_cls = \ "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: - # TODO: 投机采样 + # TODO: Speculative sampling if envs.VLLM_USE_V1: parallel_config.worker_cls = \ "vllm.v1.worker.gpu_worker.Worker" @@ -72,7 +72,7 @@ def update_worker_cls_config(cls, vllm_config: VllmConfig) -> None: "vllm.worker.worker.Worker" else: if envs.VLLM_USE_V1: - # TODO: 实现v1 + # TODO: Implement v1 parallel_config.worker_cls = \ "vllm.v1.worker.gpu_worker.Worker" else: diff --git a/roll/third_party/vllm/vllm_0_8_4/llm.py b/roll/third_party/vllm/vllm_0_8_4/llm.py index c43bc829..eab40b60 100644 --- a/roll/third_party/vllm/vllm_0_8_4/llm.py +++ b/roll/third_party/vllm/vllm_0_8_4/llm.py @@ -166,7 +166,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def clear_unfinished_requests(self): self._run_engine(use_tqdm=True) - # 参数同步接口 + # Parameter synchronization interface def setup_collective_group(self, *args, **kwargs): self.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs) diff --git a/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py b/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py index e063da92..72f3d856 100644 --- a/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py +++ b/roll/third_party/vllm/vllm_0_8_4/v1/async_llm.py @@ -74,7 +74,7 @@ def offload_states(self, level=2): self.reset_prefix_cache() self.collective_rpc(method="offload_states") - # 参数同步接口 + # Parameter synchronization interface def setup_collective_group(self, *args, **kwargs): 
self.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs) diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index e95a17cd..46130a11 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -40,8 +40,8 @@ def delete_tensor_grad_visitor(obj, path): def traverse_obj(value, visitor, path=()): """ - 遍历对象的所有属性,包括属性的属性,找到所有的 Tensor。 - :param value: 任意 Python 对象 + Traverse all attributes of an object, including attributes of attributes, to find all Tensors. + :param value: Any Python object :visitor :path """ @@ -86,7 +86,7 @@ def divide_by_chunk_size( data: Union[np.ndarray, TensorDict], chunk_sizes: List[int] ) -> List[Union[np.ndarray, TensorDict]]: """ - 将numpy数组按照chunks的大小切分 + Split numpy array according to chunk sizes """ if not isinstance(data, (np.ndarray, TensorDict)): raise TypeError("Input 'array' must be a numpy ndarray or a TensorDict.") @@ -314,7 +314,7 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T def response_level_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True): """Whiten values with masked values.""" - # 考虑response的影响? + # Consider the impact of response? mean = masked_mean(values, mask, dim=-1) var = masked_var(mean, mask) mean = mean.mean() @@ -475,7 +475,7 @@ def compute_token_reward(data: "DataProto", pipeline_config: RLVRConfig, kl_ctrl action_mask=data.batch["response_mask"][:, 1:], kl_penalty=pipeline_config.kl_penalty, ) - # 是否添加token level kl + # Whether to add token level kl if pipeline_config.add_token_level_kl and "ref_log_probs" in data.batch.keys(): beta = kl_ctrl.value token_level_rewards = token_level_rewards - beta * kld @@ -505,8 +505,8 @@ def compute_token_reward(data: "DataProto", pipeline_config: RLVRConfig, kl_ctrl def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_ctrl): response_level_rewards = data.batch["response_level_rewards"].clone().detach() response_level_metrics = {"critic/reward_clip_frac": 0.0} - # 对reward进行处理: 可以选择不同的normalization方法 - # 使用group-based normalization (按prompt分组) + # Process rewards: can choose different normalization methods + # Use group-based normalization (group by prompt) if pipeline_config.adv_estimator == "grpo" or pipeline_config.reward_norm == "group": if pipeline_config.reward_shift: response_level_rewards = group_reward_norm( @@ -521,14 +521,14 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c div_std=True, ) - # 使用batch-based normalization (整个batch) + # Use batch-based normalization (entire batch) elif pipeline_config.reward_norm == "batch": if hasattr(pipeline_config, "reward_shift") and pipeline_config.reward_shift: response_level_rewards = batch_reward_norm(response_level_rewards, div_std=False) else: response_level_rewards = batch_reward_norm(response_level_rewards, div_std=True) - # 使用running statistics进行normalization + # Use running statistics for normalization elif pipeline_config.reward_norm == "running": running = running_ctrl["domain"] running.update(response_level_rewards) @@ -541,7 +541,7 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c else: response_level_rewards = (response_level_rewards - mean) / std - # 对reward进行clip + # Clip reward if pipeline_config.reward_clip: reward_clip_frac = compute_clip_fraction( values=response_level_rewards, clip_max=pipeline_config.reward_clip, clip_min=-pipeline_config.reward_clip @@ -561,7 +561,7 @@ def get_sample_level_mask(data: "DataProto", 
pipeline_config: RLVRConfig): batch_size = data.batch["response_mask"].size(0) mask_metrics = {} - # mask相关策略 + # Mask related strategies data.batch["origin_response_mask"] = data.batch["response_mask"].clone() response_mask = data.batch["response_mask"][:, 1:].clone() true_response_length = response_mask.sum(-1).float() @@ -569,7 +569,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig): final_sample_mask = torch.ones(batch_size, device=response_mask.device) - # 1. max_len_mask: 过滤掉超过最大长度的样本 + # 1. max_len_mask: Filter out samples exceeding maximum length if pipeline_config.max_len_mask: max_len_mask = (max_response_length != true_response_length).float() final_sample_mask = final_sample_mask * max_len_mask @@ -577,7 +577,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig): else: mask_metrics["actor/max_len_mask_ratio"] = 1.0 - # 2. difficulty_mask: 基于难度的过滤 + # 2. difficulty_mask: Filter based on difficulty if pipeline_config.difficulty_mask: data = difficulty_mask( data, @@ -594,7 +594,7 @@ def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig): else: mask_metrics["actor/difficulty_mask_ratio"] = 1.0 - # 3. error_max_len_clip: 基于错误和长度的过滤 + # 3. error_max_len_clip: Filter based on errors and length if pipeline_config.error_max_len_clip: scores = data.batch["scores"] error_len_mask = ((scores == 0) & (true_response_length < pipeline_config.error_max_len_threshold)) | ( @@ -692,7 +692,7 @@ def compute_advantage( data.batch["raw_advantages"] = advantages if whiten_advantages: - # TODO whiten过程中是否要考虑response的长度? + # TODO Should response length be considered during whitening process? advantages = masked_whiten(values=advantages, mask=response_mask) advantages = advantages * response_mask @@ -725,8 +725,8 @@ def postprocess_generate( from roll.distributed.scheduler.protocol import DataProto if fill_eos_token: - # yali: 如果output最后一个token不是pad_token_id,则替换成eos_token_id, - # TODO: 需要消融这个变化的影响 + # yali: If the last token of output is not pad_token_id, replace it with eos_token_id, + # TODO: Need to ablate the impact of this change last_token_index = output.size(1) - 1 need_replace_mask = output[:, last_token_index] != pad_token_id output[need_replace_mask, last_token_index] = eos_token_id diff --git a/tests/distributed/strategy/generate/generate_pipeline.py b/tests/distributed/strategy/generate/generate_pipeline.py index 2b7e0283..f7191f34 100644 --- a/tests/distributed/strategy/generate/generate_pipeline.py +++ b/tests/distributed/strategy/generate/generate_pipeline.py @@ -62,7 +62,7 @@ def __init__(self, pipeline_config: RLVRConfig): def run(self): global_step = 0 - # 计算tokens per second 系统吞吐 + # Calculate tokens per second system throughput tps_timer = _Timer(window_size=5) metric_list = [] @@ -161,7 +161,7 @@ def __init__(self, pipeline_config: RLVRConfig): def run(self): global_step = 0 - # 计算tokens per second 系统吞吐 + # Calculate tokens per second system throughput tps_timer = _Timer(window_size=5) metric_list = [] diff --git a/tests/math/test_math_dataset.py b/tests/math/test_math_dataset.py index a45d49dc..e413f87a 100644 --- a/tests/math/test_math_dataset.py +++ b/tests/math/test_math_dataset.py @@ -12,7 +12,7 @@ dataset = load_dataset("json", data_files=dataset_path)["train"] -# 加上format,然后转ids的func +# Add format, then convert to ids function def encode_function(data_i): text_list = [] for instruct in data_i["prompt"]: @@ -26,11 +26,11 @@ def encode_function(data_i): return encodings -# 处理数据 +# Process data 
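# --- Illustrative aside (not part of the patch) ---------------------------------
# reward_postprocess above selects between group-based, batch-based and running
# normalization of response_level_rewards. A small sketch of the group-based branch
# (normalize each response against the other samples of the same prompt); this is a
# hypothetical helper for illustration, not the repo's group_reward_norm implementation:
import torch

def group_norm_sketch(rewards: torch.Tensor, prompt_ids: torch.Tensor,
                      div_std: bool = True, eps: float = 1e-6) -> torch.Tensor:
    normed = rewards.clone().float()
    for pid in prompt_ids.unique():
        idx = prompt_ids == pid
        group = normed[idx] - normed[idx].mean()
        if div_std:
            # population std, so a prompt with a single sampled response stays finite
            group = group / (group.std(unbiased=False) + eps)
        normed[idx] = group
    return normed

# e.g. two prompts with two sampled responses each:
# group_norm_sketch(torch.tensor([1., 0., 1., 1.]), torch.tensor([0, 0, 1, 1]))
# ---------------------------------------------------------------------------------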
print(dataset) dataset = dataset.map(encode_function, batched=True, desc="Encoding dataset") print(dataset) -# 过滤cutoff +# Filter cutoff dataset = dataset.filter(lambda data_i: len(data_i["input_ids"]) <= 512, desc="Filtering dataset") print(dataset) # ------ diff --git a/tests/models/cuda_mem/test_large_gemm.py b/tests/models/cuda_mem/test_large_gemm.py index 64421a67..a1c2a5df 100644 --- a/tests/models/cuda_mem/test_large_gemm.py +++ b/tests/models/cuda_mem/test_large_gemm.py @@ -1,34 +1,34 @@ import torch import torch.nn as nn -# 定义参数 -vocab_size = 128 * 1000 # 词汇表大小 -intermediate_size = 1560 # 中间隐藏层大小 +# Define parameters +vocab_size = 128 * 1000 # Vocabulary size +intermediate_size = 1560 # Intermediate hidden layer size batch_size = 1 -seq_len = 4096 # 假设序列长度 +seq_len = 4096 # Assumed sequence length -# 创建一个随机隐层输出张量 (batch_size, seq_len, intermediate_size) +# Create a random hidden layer output tensor (batch_size, seq_len, intermediate_size) hidden_output = torch.randn(batch_size, seq_len, intermediate_size).cuda() -# 创建最后一层线性层,将隐层大小映射到词汇表大小 +# Create the last linear layer, mapping hidden layer size to vocabulary size # linear_layer = nn.Linear(intermediate_size, vocab_size).cuda() -# 对每个时间步进行最后一层的计算,可以使用 reshape -# 如果需要按照 seq_len 进行计算,可以使用 batched 为 (batch_size * seq_len, intermediate_size) -# logits = linear_layer(hidden_output.view(-1, intermediate_size)) # 变形为 (batch_size * seq_len, intermediate_size) +# Compute the last layer for each time step, can use reshape +# If computing by seq_len, can use batched as (batch_size * seq_len, intermediate_size) +# logits = linear_layer(hidden_output.view(-1, intermediate_size)) # Reshape to (batch_size * seq_len, intermediate_size) -# 直接构造权重矩阵 W (intermediate_size, vocab_size) +# Directly construct weight matrix W (intermediate_size, vocab_size) weight_matrix = torch.randn(intermediate_size, vocab_size).cuda() -# 计算 logits,进行矩阵乘法 -# 对每个时间步进行最后一层的计算,可以使用 reshape +# Compute logits, perform matrix multiplication +# Compute the last layer for each time step, can use reshape logits = torch.matmul(hidden_output.view(-1, intermediate_size), weight_matrix) -# 重新调整 logits 的形状为 (batch_size, seq_len, vocab_size) +# Reshape logits to (batch_size, seq_len, vocab_size) logits = logits.view(batch_size, seq_len, vocab_size) -# 计算 softmax 以得到概率分布 -probabilities = nn.functional.softmax(logits, dim=-1) # 计算每个时间步的概率分布 +# Compute softmax to get probability distribution +probabilities = nn.functional.softmax(logits, dim=-1) # Compute probability distribution for each time step del logits, probabilities, weight_matrix, hidden_output torch.cuda.empty_cache() diff --git a/tests/third_party/megatron/test_offload_states.py b/tests/third_party/megatron/test_offload_states.py index fa148669..35b56055 100644 --- a/tests/third_party/megatron/test_offload_states.py +++ b/tests/third_party/megatron/test_offload_states.py @@ -210,7 +210,7 @@ def create_mca_model(self): """ torchrun --standalone --nnodes=1 --nproc-per-node=2 -m pytest -s tests/third_party/megatron/test_offload_states.py --s 显示stdout/err +-s displays_stdout/err """ @@ -881,7 +881,7 @@ def run_model_fp32_optimizer(mca_model: TurboModelCreator, included_state, pin_m @pytest.mark.parametrize("optimizer_type", [None, "dist_optimizer", "fp16", "fp32"]) def test_megatron_offload_states(included_state, pin_memory, non_blocking, optimizer_type): """ - 有四块非optimizer的显存未释放: + There are four blocks of non-optimizer GPU memory not released: 
/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/transformer_engine/pytorch/module/base.py:58:get_workspace
/root/.local/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py:413:forward
/root/.local/lib/python3.10/site-packages/megatron/core/models/gpt/gpt_model.py:249:forward
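The unreleased blocks listed above are easiest to spot by bracketing the offload call with an allocation check. A minimal sketch; the engine handle and its offload_states call below are placeholders for whichever strategy is under test, not an actual API from this repo's tests:

import torch

def report_cuda_mem(tag: str) -> int:
    # memory_allocated() (rather than memory_reserved()) is the number that should
    # drop once a buffer has really been returned to the caching allocator.
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated()
    print(f"[{tag}] allocated={allocated / 2**20:.1f} MiB, "
          f"reserved={torch.cuda.memory_reserved() / 2**20:.1f} MiB")
    return allocated

# typical usage around an offload cycle:
# before = report_cuda_mem("before offload")
# engine.offload_states()
# torch.cuda.empty_cache()
# after = report_cuda_mem("after offload")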