diff --git a/.gitmodules b/.gitmodules index c9d6fc80..f692b8de 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,7 @@ [submodule "third_party/webshop-minimal"] path = third_party/webshop-minimal url = https://github.com/ZihanWang314/webshop-minimal.git +[submodule "third_party/vllm"] + path = third_party/vllm + url = https://github.com/taoluo/vllm.git + branch = roll diff --git a/data/test_interrupt.jsonl b/data/test_interrupt.jsonl new file mode 100644 index 00000000..28735428 --- /dev/null +++ b/data/test_interrupt.jsonl @@ -0,0 +1 @@ +{"id": "1", "source": "deepmath_103k", "difficulty": "4.5", "prompt": "You are a senior systems researcher. Draft a 4,000-word white paper titled “post-training systems for RLHF” with these sections: abstract (≤150 words), introduction, related work, system architecture (with numbered sub-sections), evaluation methodology, experimental results (include tables), limitations, and future work. Use formal academic tone, cite at least eight landmark papers inline (APA), and end with a concise conclusion.", "messages": "[{\"role\": \"system\", \"content\": \"Please reason step by step, and put your final answer within \\\\boxed{}.\"}, {\"role\": \"user\", \"content\": \"You are a senior systems researcher. Draft a 4,000-word white paper titled “post-training systems for RLHF” with these sections: abstract (≤150 words), introduction, related work, system architecture (with numbered sub-sections), evaluation methodology, experimental results (include tables), limitations, and future work. Use formal academic tone, cite at least eight landmark papers inline (APA), and end with a concise conclusion.\"}]", "ground_truth": "1", "case_type": "", "test_case_function": "", "test_cases": "", "tag": "deepmath_103k"} diff --git a/roll/datasets/collator.py b/roll/datasets/collator.py index 1fefec2f..1455428a 100644 --- a/roll/datasets/collator.py +++ b/roll/datasets/collator.py @@ -10,6 +10,9 @@ from transformers import DataCollatorForSeq2Seq, PreTrainedTokenizerBase, ProcessorMixin, BatchFeature from transformers.data.data_collator import pad_without_fast_tokenizer_warning from transformers.utils import PaddingStrategy +from roll.utils.logging import get_logger + +logger = get_logger() def collate_fn_to_dict_list(data_list: list[dict]) -> dict: @@ -98,6 +101,20 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: padded_features = [{k: v for k, v in feature.items() if k in self.padded_keys} for feature in features] un_padded_features = [{k: v for k, v in feature.items() if k not in self.padded_keys} for feature in features] + # Debug: Log the input features + logger.info(f"COLLATOR_DEBUG: Processing {len(features)} features") + for i, feature in enumerate(features): + if 'input_ids' in feature: + input_ids = feature['input_ids'] + logger.info(f"COLLATOR_DEBUG: Feature_{i}: input_ids_len={len(input_ids)}, input_ids_first_10={input_ids[:10]}") + + # Log any text content for comparison + for key in ['prompt', 'text', 'messages', 'ground_truth']: + if key in feature: + text_data = feature[key] + sample_text = str(text_data)[:100] if text_data else "None" + logger.info(f"COLLATOR_DEBUG: Feature_{i}: {key}_sample='{sample_text}'") + batch = pad_without_fast_tokenizer_warning( self.tokenizer, padded_features, @@ -106,6 +123,13 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=self.return_tensors, ) + + # Debug: Log the output batch + if 'input_ids' in batch: + 
logger.info(f"COLLATOR_DEBUG: Output batch: input_ids_shape={batch['input_ids'].shape}") + for i in range(len(batch['input_ids'])): + logger.info(f"COLLATOR_DEBUG: Batch_output_{i}: input_ids_first_10={batch['input_ids'][i][:10].tolist()}") + batch["position_ids"] = torch.clip(torch.cumsum(batch["attention_mask"], dim=-1) - 1, min=0, max=None) un_padded_batch = collate_fn_to_dict_list(un_padded_features) batch.update(un_padded_batch) diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index ab481c14..79874421 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -3,8 +3,6 @@ import queue import random import threading -import asyncio -import uuid import time from collections import defaultdict from typing import Any, Union, Optional, Dict, List, Set @@ -30,7 +28,7 @@ ) from roll.utils.logging import get_logger from roll.utils.multi_thread_utils import ThreadSafeDict - +from pprint import pprint logger = get_logger() @@ -182,11 +180,12 @@ def generate(self, data: DataProto, actor_cluster: Union[Any, Cluster], pipeline def get_available_dp_rank(self): while True: # 负载均衡逻辑,期望各dp 正在处理的条数基本接近 - sorted_ranks = sorted( - self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank) - ) - if self.load_balance_coordinator[sorted_ranks[0]] < self.max_running_requests: - yield sorted_ranks[0] + with self.lock: + sorted_ranks = sorted( + self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank) + ) + if self.load_balance_coordinator[sorted_ranks[0]] < self.max_running_requests: + yield sorted_ranks[0] def send_request_to_one_worker(self, data: DataProto): dp_rank = next(self.get_available_dp_rank()) @@ -289,6 +288,7 @@ def generate_opt_level_1(self, data: DataProto): output=output_tensor, num_return_sequences=generate_return_num, sequence_length=self.pipeline_config.sequence_length, + canonical_prompt_length=self.pipeline_config.prompt_length, eos_token_id=eos_token_id, pad_token_id=pad_token_id, ) @@ -390,6 +390,7 @@ def __init__(self, pipeline_config=None): self.running_prompts = 0 self.response_cache: Dict[str, List] = None self.prompt_use_count = 0 + self.postprocessed_requests_count = 0 def set_scheduler( self, @@ -443,6 +444,27 @@ def set_scheduler( namespace=RAY_NAMESPACE, ).remote() + import os + os.environ.setdefault("PYDEVD_USE_CYTHON", "NO") + os.environ.setdefault("PYDEVD_USE_FRAME_EVAL", "NO") + import pydevd_pycharm + + # Differentiate schedulers by use_additional_prompts config + if hasattr(self.pipeline_config, 'use_additional_prompts') and not self.pipeline_config.use_additional_prompts: + # Validation scheduler (use_additional_prompts = False) + debug_port = 12346 + scheduler_type = "VALIDATION" + else: + # Training scheduler (use_additional_prompts = True or default) + debug_port = 12344 + scheduler_type = "TRAINING" + logger.info(f"Connecting PyCharm debugger on port {debug_port}") + if os.getenv("PYCHARM", "0") == "1": + pydevd_pycharm.settrace('localhost', port=debug_port, stdoutToServer=True, stderrToServer=True, suspend=False) + logger.info(f"PyCharm debugger attached to {scheduler_type} scheduler on port {debug_port}") + + + def reset_status(self): self.completed_buffers: Dict[int, List[DataProto]] = defaultdict(list) self.query_group_buffers: Dict[int, List[DataProto]] = defaultdict(list) @@ -466,6 +488,8 @@ def reset_status(self): desc=f"{bar_name} generate 
progress(prompt)", mininterval=int(self.batch_size * 0.1) + 1, ) + self.interrupted_query_group_buffers: Dict[int, List[DataProto]] = defaultdict(list) + def get_batch(self, data: DataProto, batch_size: int) -> DataProto: """ @@ -479,6 +503,12 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: prompt_id_counter = itertools.count() self.generation_config = copy.deepcopy(data.meta_info["generation_config"]) num_return_sequences = self.generation_config["num_return_sequences"] + + enable_dynamic_lb = True # enable dynamic load balance, if False, we will not check dp worker load and interrupt requests + rebalance_threshold = 1 # rebalance threshold, if the dp worker load difference is larger than this value, we will interrupt some requests + + last_interruption_time = None + freeze_interruption_timeout = 3 # after each interruption, wait 3 seconds to update load balancer from scheduler before check rebalance threshold while True: if ( sum([len(v) for v in list(self.completed_buffers.values())[:]]) @@ -488,13 +518,64 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: break self.check_worker_alive(self.actor_cluster) self.check_response_callback() - if not self.check_send_new_request(): - time.sleep(1) + + + if enable_dynamic_lb and (last_interruption_time is None or time.time() - last_interruption_time > freeze_interruption_timeout): + # check the difference between dp workers, + # if the difference is too large > 2 , interrupt all requests on the most loaded dp worker aggresively , + # the load balance coordinator will handle the reroute + with self.lock: + dp_imbalance_range = max(self.load_balance_coordinator.values()) - min(self.load_balance_coordinator.values()) + max_rank = max(self.load_balance_coordinator, key=self.load_balance_coordinator.get) + # leftover = dp_imbalance_range - interrupt_cnt = rebalance_threshold /2 + interrupt_cnt = dp_imbalance_range // 2 + leftover_cnt = self.load_balance_coordinator[max_rank] - interrupt_cnt + + if dp_imbalance_range >= rebalance_threshold and interrupt_cnt > 0: + # find the most loaded dp worker + + logger.info(f"Dynamic load balance: max dp rank {max_rank} load_balance_coordinator {self.load_balance_coordinator}") + + self.actor_cluster.workers[max_rank].add_request.remote( + command=GenerateRequestType.INTERRUPT, data=DataProto(meta_info={"target_leftover_cnt": leftover_cnt}) + ) + + last_interruption_time = time.time() + + + if self.interrupted_query_group_buffers: + + + target_dp_rank = next(self.get_available_dp_rank()) + + self.send_one_interrupted_query_group_to_dp_new(target_dp_rank) + logger.info(f"Migration: Sending interrupted query group to DP rank {target_dp_rank} and skip sending any new request.") + continue + if not self.check_send_new_request(): + req_bf = set(self.requests_buffers.keys()) + callback_request_ids = req_bf - self.abort_request_ids + if not callback_request_ids: + buffer_len = len(self.completed_buffers) + logger.info(f"requests callback buffers empty {req_bf=} - {self.abort_request_ids=} , buffer len {buffer_len}, sleep 5 sec to finish all requests shall finish loop soon, ") + time.sleep(5) + + + else: + + logger.info(f"buffer is not empty, {callback_request_ids=}, sleep 1 sec") + + time.sleep(1) + continue + # get a query from dataset prompt_id = next(prompt_id_counter) - dataset_item = self.get_next_dataset_item() + dataset_item = self.get_fixed_dataset_item(0) # Use fixed dataset item for testing + # dataset_item = self.get_next_dataset_item() # Use different dataset items 
+ # tao: rvst hardcode for testing debug interrupt, remove this later + # dataset_item = self.get_fixed_dataset_item(0) # This was causing identical input_ids! + domain = dataset_item.get("domain", "default") collect_data = self.collect_fn([dataset_item]) request_data: DataProto = DataProto.from_single_dict(collect_data, meta_info=data.meta_info) @@ -503,6 +584,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: request_data_list = self.expand_requests(request_data) dp_rank = next(self.get_available_dp_rank()) + dp_rank = 0 with self.lock: self.prompt_use_count += 1 self.running_prompts += 1 @@ -526,14 +608,44 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: completed_buffers = {k: v for k, v in self.completed_buffers.items() if len(v) > 0} collect_data = [item for sublist in list(completed_buffers.values())[:] for item in sublist] + assert len(collect_data) > 0, f"No collect data found, check if the dataset is empty or all requests are filtered out. {self.completed_buffers}" + # **LOG ALL INDIVIDUAL REQUESTS**: Before concatenation, log each request's detailed info + logger.info(f"SCHEDULER_COLLECT_DATA: Found {len(collect_data)} total responses from {len(completed_buffers)} queries") + for i, data_item in enumerate(collect_data): + request_id = data_item.meta_info.get("request_id", f"unknown_{i}") + finish_status = data_item.meta_info.get("finish_status", "UNKNOWN") + is_continued = data_item.meta_info.get("is_continued_request", False) + migration_count = data_item.meta_info.get("migration_count", 0) + original_request_id = data_item.meta_info.get("original_request_id", request_id) + domain = data_item.non_tensor_batch.get("domain", ["UNKNOWN"])[0] if "domain" in data_item.non_tensor_batch else "UNKNOWN" + + # Decode prompt and response for logging + if "prompts" in data_item.batch: + prompt_text = self.tokenizer.decode(data_item.batch["prompts"][0], skip_special_tokens=True) + else: + prompt_text = "NO_PROMPT_IN_BATCH" + + if "responses" in data_item.batch: + response_text = self.tokenizer.decode(data_item.batch["responses"][0], skip_special_tokens=True) + response_length = len(data_item.batch["responses"][0]) + else: + response_text = "NO_RESPONSE_IN_BATCH" + response_length = 0 + + logger.info(f"COLLECT_request_id={request_id}, original_id={original_request_id}, domain={domain}, is_continued={is_continued}, migrations={migration_count}, finish_status={finish_status}, response_length={response_length}") + logger.info(f"COLLECT_PROMPT_{request_id}: \n{prompt_text}") + logger.info(f"COLLECT_RESPONSE_{request_id}: \n{response_text}") + query_use_count = next(prompt_id_counter) logger.info( f"total collect data: {len(collect_data)}, collect queries: {len(completed_buffers)} " f"used queries: {query_use_count} query_filter_count: {self.query_filter_count} " f"response_filter_count: {self.response_filter_count}" ) + # TODO: 这里 len(collect_data) > rollout_batch_size, 可以尝试动态扩大batch_size - batch = DataProto.concat(collect_data[: self.batch_size * num_return_sequences]) + batch_size = min(self.batch_size * num_return_sequences, len(collect_data)) + batch = DataProto.concat(collect_data[: batch_size]) batch.meta_info["metrics"] = { f"scheduler/query_filter_count": self.query_filter_count, f"scheduler/response_filter_count": self.response_filter_count, @@ -555,25 +667,275 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: return batch + def send_one_interrupted_query_group_to_dp_new(self, target_dp_rank: int): + """ + Send interrupted 
query group to a target DP rank for continuation. + Enhanced to handle multiple interruptions/migrations by tracking original prompt length + and cumulative partial output length. + """ + with self.lock: + assert self.interrupted_query_group_buffers, "Migration: No interrupted query groups in buffer to migrate" + + prompt_id, interrupted_batches = self.interrupted_query_group_buffers.popitem() + assert len(interrupted_batches) > 0, f"Migration: Empty interrupted batches for prompt_id {prompt_id}" + + # --- AGGREGATE ALL PARTIAL OUTPUTS --- + # The core fix is to process all partial outputs for a request together, + # not individually in a loop. + + # 1. Get the original request details from the first interrupted batch. + # All batches for a prompt_id share the same original request. + first_batch = interrupted_batches[0] + original_request_id = first_batch.meta_info["request_id"] + assert original_request_id in self.requests_buffers, f"Original request not found for {original_request_id}" + original_request = self.requests_buffers[original_request_id] + original_prompt_ids = original_request.batch["input_ids"] + original_attention_mask = original_request.batch["attention_mask"] + original_prompt_length = original_prompt_ids.shape[1] + + # 2. Concatenate all partial token chunks from all interruptions. + all_partial_tokens = [] + for batch in interrupted_batches: + partial_tokens = batch.meta_info.get("output_token_ids", []) + # Validate that we only have a single sequence + if partial_tokens and len(partial_tokens) > 1: + logger.error( + f"Migration error: Detected {len(partial_tokens)} sequences in interrupted batch " + f"for request_id={batch.meta_info.get('request_id', 'unknown')}, " + f"prompt_id={prompt_id}" + ) + raise AssertionError( + f"Multiple sequences not supported for request interruption/migration. " + f"Found {len(partial_tokens)} sequences in output_token_ids. " + f"Please use is_num_return_sequences_expand=True to handle multiple sequences as separate requests." + ) + if partial_tokens and len(partial_tokens) > 0 and len(partial_tokens[0]) > 0: + all_partial_tokens.extend(partial_tokens[0]) + + cumulative_partial_output_length = len(all_partial_tokens) + + # 3. Build the new, fully continued input. + if cumulative_partial_output_length > 0: + partial_output_tensor = torch.tensor(all_partial_tokens, device=original_prompt_ids.device) + continued_input_ids = torch.cat([original_prompt_ids.squeeze(0), partial_output_tensor], dim=0).unsqueeze(0) + + # Extend the attention mask to cover the new tokens. + partial_output_mask = torch.ones((1, cumulative_partial_output_length), device=original_prompt_ids.device, dtype=torch.long) + continued_attention_mask = torch.cat([original_attention_mask, partial_output_mask], dim=1) + else: + # No new tokens, resubmit the original prompt. + continued_input_ids = original_prompt_ids + continued_attention_mask = original_attention_mask + + continued_position_ids = torch.arange(continued_input_ids.shape[1], device=continued_input_ids.device).unsqueeze(0) + + # --- NEW ASSERTIONS to validate the fix --- + expected_total_length = original_prompt_length + cumulative_partial_output_length + actual_total_length = continued_input_ids.shape[1] + assert actual_total_length == expected_total_length, \ + f"Assertion Failed: Reconstructed length mismatch. 
Expected {expected_total_length}, got {actual_total_length}" + assert continued_input_ids.shape[1] == continued_attention_mask.shape[1], \ + f"Assertion Failed: Mismatch between input_ids shape ({continued_input_ids.shape[1]}) and attention_mask shape ({continued_attention_mask.shape[1]})" + + # 4. Create the single new migrated request. + # Use metadata from the *last* interrupted batch as it's the most recent. + last_batch = interrupted_batches[-1] + migrated_request = DataProto() + + batch_tensors = { + "input_ids": continued_input_ids, + "attention_mask": continued_attention_mask, + "position_ids": continued_position_ids, + } + # Copy other tensor fields from the original request + for key in original_request.batch.keys(): + if key not in batch_tensors: + batch_tensors[key] = original_request.batch[key] + + migrated_request.batch = TensorDict(source=batch_tensors, batch_size=[1]) + migrated_request.non_tensor_batch = copy.deepcopy(original_request.non_tensor_batch) + + migrated_request.meta_info = last_batch.meta_info.copy() + migrated_request.meta_info.pop('finish_status', None) + migrated_request.meta_info["response_callback_fn"] = self.response_callback_fn + migrated_request.meta_info["is_continued_request"] = True + migrated_request.meta_info["original_prompt_length"] = original_prompt_length + migrated_request.meta_info["cumulative_partial_output_length"] = cumulative_partial_output_length + migrated_request.meta_info["migration_count"] = len(interrupted_batches) + + # Adjust max_new_tokens for the continued generation. + generation_config = migrated_request.meta_info["generation_config"].copy() + max_sequence_length = self.pipeline_config.sequence_length + original_max_new_tokens = generation_config["max_new_tokens"] + + # Calculate remaining tokens from original request + remaining_from_original_request = original_max_new_tokens - cumulative_partial_output_length + # Calculate remaining space in sequence + remaining_sequence_space = max_sequence_length - actual_total_length + # Take minimum of both constraints + max_allowed_new_tokens = min(remaining_from_original_request, remaining_sequence_space) + generation_config["max_new_tokens"] = max_allowed_new_tokens + migrated_request.meta_info["generation_config"] = generation_config + + # Reuse the original request ID for the resumed request. + migrated_request.meta_info["request_id"] = original_request_id + + # Update the buffers and mappings with the new state for the original request ID. + self.requests_buffers[original_request_id] = migrated_request + self.request_id_2_prompt_id[original_request_id] = prompt_id + self.request_id_2_dp_rank[original_request_id] = target_dp_rank + # The original_request_id should already be in self.prompt_id_2_request_ids, so no need to re-add. + + ray.get(self.actor_cluster.workers[target_dp_rank].add_request.remote(command=GenerateRequestType.ADD, + data=migrated_request)) + self.load_balance_coordinator[target_dp_rank] += 1 + logger.info( + f"Successfully resumed prompt {prompt_id} to dp rank {target_dp_rank} with original request id {original_request_id}" + ) + + + def interrupt_all_requests_by_dp_rank(self, interrupted_rank): + # assert False, "Migration: interrupt_all_requests_by_dp_rank is not implemented yet" + # return + # 1. 
remove the interrupted rank from the active dp ranks + logger.info(f"Migration: Removing DP rank {interrupted_rank} from ready ranks") + # self.ready_dp_ranks.remove(interrupted_rank) + + # some might be interrupted, some might be aborted or interrupted_rank + request_ids = self.get_running_request_ids_for_dp_rank(interrupted_rank) + assert len(request_ids) > 0, "no requests are informed interruption" + logger.info( + f"Migration: inform interrupting {len(request_ids)} requests from DP rank {interrupted_rank} request list: {request_ids} ") + interrupt_refs = [] + + for request_id in request_ids: + # dp_rank = self.request_id_2_dp_rank[request_id] + interrupt_refs.append( + self.actor_cluster.workers[interrupted_rank].add_request.remote( + command=GenerateRequestType.INTERRUPT, data=DataProto(meta_info={"request_id": request_id}) + ) + ) + + def get_running_request_ids_for_dp_rank(self, target_dp_rank: int) -> List[str]: + """Get all request_ids currently assigned to a specific DP rank""" + running_request_ids = [] + with self.lock: + # all unfinished requests + for request_id in self.requests_buffers.keys(): + if self.request_id_2_dp_rank[request_id] == target_dp_rank and request_id not in self.abort_request_ids: + running_request_ids.append(request_id) + + return running_request_ids + @ray.method(concurrency_group="multi_thread") def report_response(self, data: DataProto): """ 这里需要考虑多线程数据访问 data 返回可能有多条的 """ + + + try: + logger.info(f"report_response: {data.meta_info['request_id']} {data.meta_info['finish_status']}") request_id = data.meta_info["request_id"] + # if request_id == '5': + # if False: + # pydevd_pycharm.settrace( + # 'localhost', + # port=12332, + # stdoutToServer=True, + # stderrToServer=True, + # suspend=False, + # trace_only_current_thread=True + # ) + # else: + # # pydevd_pycharm.settrace( + # # 'localhost', + # # port=9999, + # # stdoutToServer=True, + # # stderrToServer=True, + # # suspend=False, + # # trace_only_current_thread=True + # # ) + # + # while True: + # pass + prompt_id = self.request_id_2_prompt_id[request_id] num_return_sequences = self.generation_config["num_return_sequences"] + assert data.meta_info["finish_status"] in ["interrupted", 'finished'] + with self.lock: + if data.meta_info["finish_status"] == "interrupted": + # **ENHANCED LOGGING**: Track prompt length and output lengths for interrupted requests + + # Get the original request to analyze lengths + original_request = self.requests_buffers.get(request_id, None) + original_prompt_length = 0 + cumulative_partial_output_length = 0 + migration_count = 0 + + if original_request: + original_input_ids = original_request.batch["input_ids"] + concatenated_input_length = original_input_ids.shape[1] + + # Check if this is a continued request + is_continued_request = original_request.meta_info.get("is_continued_request", False) + if is_continued_request: + # For continued requests, get the actual original prompt length from metadata + original_prompt_length = original_request.meta_info.get("original_prompt_length", 1024) + cumulative_partial_output_length = original_request.meta_info.get("cumulative_partial_output_length", 0) + migration_count = original_request.meta_info.get("migration_count", 0) + else: + # For original requests, the input length is the original prompt length + original_prompt_length = concatenated_input_length + + # Get the newly generated output tokens from this interruption + output_token_ids = data.meta_info.get("output_token_ids", []) + newly_generated_length = 0 + if output_token_ids 
and len(output_token_ids) > 0 and len(output_token_ids[0]) > 0: + newly_generated_length = len(output_token_ids[0]) + + # Single comprehensive log entry + logger.info(f"Migration: BUFFERING interrupted request {request_id}: " + f"original_prompt_length={original_prompt_length}, " + f"cumulative_partial_output_length={cumulative_partial_output_length}, " + f"newly_generated_length={newly_generated_length}, " + f"migration_count={migration_count}") + + self.interrupted_query_group_buffers[prompt_id].append(data) + logger.info( + f"Migration: Added interrupted batch for prompt_id {prompt_id} to buffer {list(self.interrupted_query_group_buffers.keys())}") + # assert False, "can interrupt and buffer the response, but not complete it yet" + self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1 + return + + assert data.meta_info["finish_status"] == "finished" + # with lock batch = self.postprocess_output_ids(data) output_count = batch.batch.batch_size[0] + with self.lock: self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1 self.prompt_id_2_request_ids[prompt_id].remove(request_id) domain = "default" + assert "domain" in batch.non_tensor_batch.keys(), f"{prompt_id=} {request_id=} batch.non_tensor_batch keys: {list(batch.non_tensor_batch.keys())} should contain domain" + if "domain" in batch.non_tensor_batch.keys(): domain = batch.non_tensor_batch["domain"][0] + + logger.info( + f"{request_id=} batch.non_tensor_batch: {list(batch.non_tensor_batch.keys())} self.reward_worker_iters {list(self.reward_worker_iters.keys())}") + + if domain == "default": + import pydevd_pycharm + import os + if os.getenv("PYCHARM", "0") == "1": + pydevd_pycharm.settrace('localhost', port=12332, stdoutToServer=True, stderrToServer=True, + suspend=False) + assert False, f"batch.non_tensor_batch : {list(batch.non_tensor_batch.keys())} self.reward_worker_iters {list(self.reward_worker_iters.keys())}" + reward_worker = next(self.reward_worker_iters[domain]) if not self.running: @@ -582,7 +944,7 @@ def report_response(self, data: DataProto): # call reward # reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意... 
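Editor's sketch of the continuation logic in send_one_interrupted_query_group_to_dp_new, which reduces to concatenating the original prompt with every partial output seen so far and recomputing the max_new_tokens budget. Minimal illustration under the single-sequence assumption enforced above, with plain tensors instead of DataProto; rebuild_continued_input is an illustrative name only.

import torch
from typing import List, Tuple

def rebuild_continued_input(original_input_ids: torch.Tensor,       # shape (1, prompt_len)
                            original_attention_mask: torch.Tensor,  # shape (1, prompt_len)
                            partial_chunks: List[List[int]],         # tokens from each interruption
                            original_max_new_tokens: int,
                            max_sequence_length: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Concatenate the original prompt with all partial outputs produced so far and
    compute the remaining max_new_tokens budget for the resumed request."""
    all_partial = [t for chunk in partial_chunks for t in chunk]
    if all_partial:
        partial = torch.tensor([all_partial], dtype=original_input_ids.dtype)
        input_ids = torch.cat([original_input_ids, partial], dim=1)
        attention_mask = torch.cat([original_attention_mask, torch.ones_like(partial)], dim=1)
    else:
        # No new tokens yet: resubmit the original prompt unchanged.
        input_ids, attention_mask = original_input_ids, original_attention_mask

    total_len = input_ids.shape[1]
    # Two constraints: what the original request is still owed, and what fits in the sequence.
    remaining_from_original = original_max_new_tokens - len(all_partial)
    remaining_sequence_space = max_sequence_length - total_len
    max_new_tokens = min(remaining_from_original, remaining_sequence_space)

    assert input_ids.shape[1] == attention_mask.shape[1]
    return input_ids, attention_mask, max_new_tokens

prompt = torch.arange(8).unsqueeze(0)
mask = torch.ones_like(prompt)
ids, am, budget = rebuild_continued_input(prompt, mask, [[100, 101], [102]],
                                          original_max_new_tokens=16,
                                          max_sequence_length=32)
print(ids.shape[1], budget)  # 11 tokens so far, budget = min(16 - 3, 32 - 11) = 13
# The scheduler hunk additionally rebuilds position_ids with torch.arange and reuses
# the original request_id so the resumed request replaces the old entry in its buffers.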
# 多域的时候,llm as judge, 需要单独为reward worker分配gpu - rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch)) + rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch),timeout=30) batch.union(rewards) response_buffers: List[DataProto] = [] @@ -630,6 +992,25 @@ def report_response(self, data: DataProto): except Exception as e: self.exception_queue.put(e) + + def get_fixed_dataset_item(self, dataset_index=0): + """Fixed dataset item for testing - always returns the same item""" + dataset_item = self.dataset[dataset_index] + logger.info(f"FIXED_DATASET_DEBUG: Using fixed dataset item at index {dataset_index}") + + # Log the fixed dataset item details + for key in ['prompt', 'text', 'messages', 'ground_truth', 'input_ids']: + if key in dataset_item: + + data = dataset_item[key] + if key == 'input_ids': + logger.info(f"FIXED_DATASET_DEBUG: {key}_len={len(data)}, first_10={data[:10]}") + else: + sample_text = str(data)[:100] if data else "None" + logger.info(f"FIXED_DATASET_DEBUG: {key}_sample='{sample_text}'") + + return dataset_item + def get_next_dataset_item(self): if self.dataset_iter is None: random.seed(self.pipeline_config.seed + self.dataset_epoch) @@ -638,7 +1019,13 @@ def get_next_dataset_item(self): logger.info(f"{'-'.join(self.reward_clusters.keys())} dataset epoch: {self.dataset_epoch}") try: - dataset_item = self.dataset[next(self.dataset_iter)] + item_index = next(self.dataset_iter) + logger.info(f"Dataset length: {len(self.dataset)}, retrieving get_next_dataset_item at index: {item_index}") + dataset_item = self.dataset[item_index] + + # dataset_item = self.dataset[next(self.dataset_iter)] + # tao: rvst hardcode for testing debug interrupt, remove this later + # dataset_item = self.dataset[0] except StopIteration: self.dataset_epoch += 1 random.seed(self.pipeline_config.seed + self.dataset_epoch) @@ -654,21 +1041,29 @@ def get_scheduler_state(self): def abort_requests(self, request_ids: Set[str]): abort_refs = [] - self.running_prompts -= 1 + request_ids = request_ids.copy() for request_id in request_ids: + assert request_id not in self.abort_request_ids dp_rank = self.request_id_2_dp_rank[request_id] self.load_balance_coordinator[dp_rank] -= 1 + prompt_id = self.request_id_2_prompt_id[request_id] + self.prompt_id_2_request_ids[prompt_id].remove(request_id) + if len(self.prompt_id_2_request_ids[prompt_id]) == 0: + self.running_prompts -= 1 abort_refs.append( self.actor_cluster.workers[dp_rank].add_request.remote( command=GenerateRequestType.ABORT, data=DataProto(meta_info={"request_id": request_id}) ) ) + self.abort_request_ids.add(request_id) def postprocess_output_ids(self, data: DataProto) -> DataProto: # postprocess_generate, input_ids, attention_mask, left pad request_id = data.meta_info["request_id"] - request: DataProto = self.requests_buffers.pop(request_id) - + logger.info(f"postprocess_output_ids: {request_id=}") + with self.lock: + request: DataProto = self.requests_buffers.pop(request_id) + self.postprocessed_requests_count +=1 eos_token_id = data.meta_info["eos_token_id"] pad_token_id = data.meta_info["pad_token_id"] output_token_ids = data.meta_info["output_token_ids"] @@ -682,6 +1077,7 @@ def postprocess_output_ids(self, data: DataProto) -> DataProto: output=output_tensor, num_return_sequences=len(output_tokens), sequence_length=self.pipeline_config.sequence_length, + canonical_prompt_length=self.pipeline_config.prompt_length, eos_token_id=eos_token_id, pad_token_id=pad_token_id, ) @@ -738,11 +1134,12 @@ def 
check_send_new_request(self) -> bool: def get_available_dp_rank(self): while True: # 负载均衡逻辑,期望各dp 正在处理的条数基本接近 - sorted_ranks = sorted( - self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank) - ) - if self.load_balance_coordinator[sorted_ranks[0]] < self.max_running_requests: - yield sorted_ranks[0] + with self.lock: + sorted_ranks = sorted( + self.load_balance_coordinator.keys(), key=lambda rank: (self.load_balance_coordinator[rank], rank) + ) + if self.load_balance_coordinator[sorted_ranks[0]] < self.max_running_requests: + yield sorted_ranks[0] @ray.remote @@ -755,38 +1152,34 @@ def get_value(self): return self.value -@ray.remote +@ray.remote(concurrency_groups={"single_thread": 1, "multi_thread": 2048}) class RequestScheduler: def __init__(self, infer_cluster, pipeline_config): self.infer_cluster = infer_cluster self.pipeline_config = pipeline_config - self.request_id = uuid.uuid4() - self.request_counter = 0 - self.src_rank2_dp_rank = {} + self.request_dict = ThreadSafeDict() self.request_id_2_dp_rank = {} - self.inflight_requests: List[Dict[str, asyncio.Future]] = [{} for _ in range(self.infer_cluster.world_size)] + self.src_rank2_dp_rank = {} self.worker_iter = itertools.cycle(range(self.infer_cluster.world_size)) - self.need_suspend = False - self.suspend_notifier = asyncio.Event() - - async def generate_one_request(self, data: DataProto): - await self._check_suspend() + @ray.method(concurrency_group="multi_thread") + def generate_one_request(self, data: DataProto): + assert "request_id" in data.meta_info, f"data {data.meta_info} should have key 'request_id'" + request_id = data.meta_info["request_id"] src_rank = data.meta_info["src_rank"] if src_rank not in self.src_rank2_dp_rank: dp_rank = next(self.worker_iter) self.src_rank2_dp_rank[src_rank] = dp_rank + dp_rank = self.src_rank2_dp_rank[src_rank] - request_id = f"{self.request_id}_{self.request_counter}" - self.request_counter += 1 - data.meta_info["request_id"] = request_id - fut = asyncio.Future() + # send request to one worker + ray.get(self.infer_cluster.workers[dp_rank].add_request.remote(command=GenerateRequestType.ADD, data=data)) + data.meta_info.pop("response_callback_fn") self.request_id_2_dp_rank[request_id] = dp_rank - self.inflight_requests[dp_rank][request_id] = fut - ref = self.infer_cluster.workers[dp_rank].add_request.remote(command=GenerateRequestType.ADD, data=data) - await asyncio.wrap_future(ref.future()) - response_data = await fut + + response_data: DataProto = self.request_dict.pop(data.meta_info["request_id"]) + self.request_id_2_dp_rank.pop(data.meta_info["request_id"]) if response_data is None: # request aborted return None @@ -813,46 +1206,170 @@ async def generate_one_request(self, data: DataProto): output.meta_info = request_repeat.meta_info return output - async def report_response(self, data: DataProto, is_abort=False): - request_id = data.meta_info["request_id"] - if request_id not in self.request_id_2_dp_rank: - return - dp_rank = self.request_id_2_dp_rank.pop(request_id) - fut = self.inflight_requests[dp_rank].pop(request_id) - if is_abort: - fut.set_result(None) - else: - fut.set_result(data) + @ray.method(concurrency_group="multi_thread") + def report_response(self, data: DataProto): + """ + 这里需要考虑多线程数据访问 + data 返回可能有多条的 + """ - async def abort_request(self): - futures = [] - for i in range(self.infer_cluster.world_size): - if len(self.inflight_requests[i]) == 0: - continue - ref = self.infer_cluster.workers[i].add_request.remote( - 
command=GenerateRequestType.ABORT, data=DataProto( - meta_info={"request_id": [request_id for request_id in self.inflight_requests[i].keys()]} + import pydevd_pycharm + + + try: + logger.info(f"report_response: {data.meta_info['request_id']} {data.meta_info['finish_status']}") + request_id = data.meta_info["request_id"] + # if request_id == '5': + # if False: + # pydevd_pycharm.settrace( + # 'localhost', + # port=12332, + # stdoutToServer=True, + # stderrToServer=True, + # suspend=False, + # trace_only_current_thread=True + # ) + # else: + # # pydevd_pycharm.settrace( + # # 'localhost', + # # port=9999, + # # stdoutToServer=True, + # # stderrToServer=True, + # # suspend=False, + # # trace_only_current_thread=True + # # ) + # + # while True: + # pass + + prompt_id = self.request_id_2_prompt_id[request_id] + num_return_sequences = self.generation_config["num_return_sequences"] + assert data.meta_info["finish_status"] in ["interrupted", 'finished'] + with self.lock: + if data.meta_info["finish_status"] == "interrupted": + # **ENHANCED LOGGING**: Track prompt length and output lengths for interrupted requests + + # Get the original request to analyze lengths + original_request = self.requests_buffers.get(request_id, None) + original_prompt_length = 0 + cumulative_partial_output_length = 0 + migration_count = 0 + + if original_request: + original_input_ids = original_request.batch["input_ids"] + concatenated_input_length = original_input_ids.shape[1] + + # Check if this is a continued request + is_continued_request = original_request.meta_info.get("is_continued_request", False) + if is_continued_request: + # For continued requests, get the actual original prompt length from metadata + original_prompt_length = original_request.meta_info.get("original_prompt_length", 1024) + cumulative_partial_output_length = original_request.meta_info.get("cumulative_partial_output_length", 0) + migration_count = original_request.meta_info.get("migration_count", 0) + else: + # For original requests, the input length is the original prompt length + original_prompt_length = concatenated_input_length + + # Get the newly generated output tokens from this interruption + output_token_ids = data.meta_info.get("output_token_ids", []) + newly_generated_length = 0 + if output_token_ids and len(output_token_ids) > 0 and len(output_token_ids[0]) > 0: + newly_generated_length = len(output_token_ids[0]) + + # Single comprehensive log entry + logger.info(f"Migration: BUFFERING interrupted request {request_id}: " + f"original_prompt_length={original_prompt_length}, " + f"cumulative_partial_output_length={cumulative_partial_output_length}, " + f"newly_generated_length={newly_generated_length}, " + f"migration_count={migration_count}") + + self.interrupted_query_group_buffers[prompt_id].append(data) + logger.info( + f"Migration: Added interrupted batch for prompt_id {prompt_id} to buffer {list(self.interrupted_query_group_buffers.keys())}") + # assert False, "can interrupt and buffer the response, but not complete it yet" + self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1 + return + + assert data.meta_info["finish_status"] == "finished" + + # with lock + batch = self.postprocess_output_ids(data) + output_count = batch.batch.batch_size[0] + + with self.lock: + self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1 + self.prompt_id_2_request_ids[prompt_id].remove(request_id) + domain = "default" + assert "domain" in batch.non_tensor_batch.keys(), f"{prompt_id=} {request_id=} 
batch.non_tensor_batch keys: {list(batch.non_tensor_batch.keys())} should contain domain" + + if "domain" in batch.non_tensor_batch.keys(): + domain = batch.non_tensor_batch["domain"][0] + + logger.info( + f"{request_id=} batch.non_tensor_batch: {list(batch.non_tensor_batch.keys())} self.reward_worker_iters {list(self.reward_worker_iters.keys())}") + + if domain == "default": + import pydevd_pycharm + import os + if os.getenv("PYCHARM", "0") == "1": + pydevd_pycharm.settrace('localhost', port=12332, stdoutToServer=True, stderrToServer=True, + suspend=False) + assert False, f"batch.non_tensor_batch : {list(batch.non_tensor_batch.keys())} self.reward_worker_iters {list(self.reward_worker_iters.keys())}" + + reward_worker = next(self.reward_worker_iters[domain]) + + if not self.running: + return + + # call reward + # reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意... + # 多域的时候,llm as judge, 需要单独为reward worker分配gpu + rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch), timeout=30) + batch.union(rewards) + + response_buffers: List[DataProto] = [] + batch_expanded = [batch[[idx]] for idx in range(output_count)] + + # response_filter, 不太需要response filter + for batch_item in batch_expanded: + if self.response_filter_fn(batch_item, self.pipeline_config): + response_buffers.append(batch_item) + else: + self.response_filter_count += 1 + + with self.lock: + self.response_cache[domain].extend(batch_expanded) + + if len(response_buffers) == 0: + if len(self.prompt_id_2_request_ids[prompt_id]) == 0: + self.running_prompts -= 1 + return + + if len(self.completed_buffers[prompt_id]) > 0: + return + + # expand batch to response + self.query_group_buffers[prompt_id].extend(response_buffers) + + # query_filter, query has n responses + if len(self.query_group_buffers[prompt_id]) >= num_return_sequences: + if not self.query_filter_fn(self.query_group_buffers[prompt_id], self.pipeline_config): + self.query_filter_count += 1 + del self.query_group_buffers[prompt_id] + self.abort_requests(self.prompt_id_2_request_ids[prompt_id]) + return + + assert len(self.query_group_buffers[prompt_id]) >= num_return_sequences, ( + f"expect to generate {num_return_sequences} results from one prompt, " + f"but get {len(self.query_group_buffers[prompt_id])}." 
) - ) - futures.append(ref) - for request_id in self.inflight_requests[i].keys(): - futures.append(self.report_response(data=DataProto(meta_info={"request_id": request_id}), is_abort=True)) - # must await at last, because report_response will mut inflight_requests - await asyncio.gather(*futures) - - async def _check_suspend(self): - while self.need_suspend: - await self.suspend_notifier.wait() - - async def suspend(self): - if self.need_suspend: - return - self.suspend_notifier.clear() - self.need_suspend = True - await self.abort_request() - def resume(self): - if not self.need_suspend: - return - self.need_suspend = False - self.suspend_notifier.set() + self.completed_buffers[prompt_id] = self.query_group_buffers[prompt_id][:num_return_sequences] + self.progress_bar.update() + + # abort uncompleted request + self.abort_requests(self.prompt_id_2_request_ids[prompt_id]) + except Exception as e: + self.exception_queue.put(e) + + pydevd_pycharm.stoptrace() diff --git a/roll/distributed/scheduler/protocol.py b/roll/distributed/scheduler/protocol.py index 9f4bcd06..f34eff2c 100644 --- a/roll/distributed/scheduler/protocol.py +++ b/roll/distributed/scheduler/protocol.py @@ -18,6 +18,9 @@ from torch.utils.data import DataLoader from roll.utils.functionals import union_two_dict, divide_by_chunk_size +from roll.utils.logging import get_logger + +logger = get_logger() try: tensordict.set_lazy_legacy(False).set() @@ -606,7 +609,86 @@ def concat(data: List["DataProto"]) -> "DataProto": for batch in data: if batch.batch is not None: batch_lst.append(batch.batch) + if len(batch_lst) > 0 and batch_lst[0] is not None: + # Add comprehensive logging to verify tensor dimensions before concatenation + logger.info(f"[CONCAT] Attempting to concatenate {len(batch_lst)} batches") + + # Collect all tensor keys and verify dimensions + all_keys = set() + for i, batch in enumerate(batch_lst): + all_keys.update(batch.keys()) + + logger.info(f"[CONCAT] Found tensor keys: {list(all_keys)}") + + # Verify dimensions for each key across all batches + dimension_mismatches = [] + for key in all_keys: + shapes = [] + request_ids = [] + + for i, batch in enumerate(batch_lst): + if key in batch: + shape = batch[key].shape + shapes.append(shape) + # Try to get request_id from meta_info if available + request_id = "unknown" + if i < len(data) and data[i].meta_info and "request_id" in data[i].meta_info: + request_id = data[i].meta_info["request_id"] + request_ids.append(request_id) + + # Check for dimension mismatches (excluding dim=0 which should vary) + if len(shapes) > 1: + expected_shape = shapes[0][1:] # All dimensions except batch dimension + for j, shape in enumerate(shapes[1:], 1): + actual_shape = shape[1:] # All dimensions except batch dimension + if expected_shape != actual_shape: + mismatch_info = { + "key": key, + "expected_shape": expected_shape, + "actual_shape": actual_shape, + "expected_request_id": request_ids[0], + "mismatched_request_id": request_ids[j], + "all_shapes": shapes, + "all_request_ids": request_ids + } + dimension_mismatches.append(mismatch_info) + + logger.error(f"[CONCAT] DIMENSION MISMATCH for key '{key}':") + logger.error(f" Expected shape: {expected_shape} (request_id: {request_ids[0]})") + logger.error(f" Actual shape: {actual_shape} (request_id: {request_ids[j]})") + logger.error(f" All shapes for this key: {shapes}") + logger.error(f" All request IDs for this key: {request_ids}") + + # Log summary of all tensor shapes for debugging + logger.info(f"[CONCAT] Summary of tensor shapes across 
all batches:") + for key in all_keys: + shapes = [] + request_ids = [] + for i, batch in enumerate(batch_lst): + if key in batch: + shapes.append(batch[key].shape) + request_id = "unknown" + if i < len(data) and data[i].meta_info and "request_id" in data[i].meta_info: + request_id = data[i].meta_info["request_id"] + request_ids.append(request_id) + logger.info(f" Key '{key}': shapes={shapes}, request_ids={request_ids}") + + # If there are dimension mismatches, log detailed error and raise exception + if dimension_mismatches: + logger.error(f"[CONCAT] Found {len(dimension_mismatches)} dimension mismatches:") + for mismatch in dimension_mismatches: + logger.error(f" Key: {mismatch['key']}") + logger.error(f" Expected: {mismatch['expected_shape']} (request_id: {mismatch['expected_request_id']})") + logger.error(f" Actual: {mismatch['actual_shape']} (request_id: {mismatch['mismatched_request_id']})") + logger.error(f" All shapes: {mismatch['all_shapes']}") + logger.error(f" All request IDs: {mismatch['all_request_ids']}") + + # Raise the original error with additional context + raise RuntimeError(f"Tensor dimension mismatch detected. Found {len(dimension_mismatches)} mismatches. Check logs for details.") + + # If all dimensions match, proceed with concatenation + logger.info(f"[CONCAT] All tensor dimensions verified. Proceeding with concatenation.") new_batch = torch.cat(batch_lst, dim=0) else: new_batch = None diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index bf16a461..22eccd34 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -1,31 +1,31 @@ -import asyncio import copy import gc import itertools import os import queue from concurrent import futures -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union, Dict +import asyncio import ray import torch import torch.distributed as dist from torch.nn.utils.rnn import pad_sequence from transformers import set_seed -from vllm import RequestOutput, SamplingParams -from vllm.lora.request import LoRARequest +from mcore_adapter.models.converter.convert_utils import RecvBucketManager +from vllm import SamplingParams, RequestOutput from vllm.utils import random_uuid -from mcore_adapter.models.converter.convert_utils import RecvBucketManager from roll.distributed.executor.worker import Worker from roll.distributed.scheduler.protocol import DataProto from roll.distributed.strategy.strategy import InferenceStrategy -from roll.third_party.vllm import LLM, AsyncLLM +from roll.third_party.vllm import LLM +from roll.third_party.vllm import AsyncLLM from roll.utils.collective import collective -from roll.utils.functionals import GenerateRequestType, concatenate_input_and_output +from roll.utils.functionals import concatenate_input_and_output, GenerateRequestType from roll.utils.logging import get_logger from roll.utils.offload_states import OffloadStateType - +import threading logger = get_logger() @@ -41,10 +41,16 @@ def __init__(self, worker: Worker): self.recv_manager = RecvBucketManager() self.command_queue: Optional[queue.Queue] = None - self.request_metas = {} + self.request_metas = {} # used to keep track of requests to callback self.group_name = "vllm_worker_default" self.running = False + self.interrupted_rid_set = set() + self.lock = threading.Lock() + self.count_calls = {} + + + def initialize(self, model_provider): set_seed(seed=self.worker.pipeline_config.seed) vllm_config = 
copy.deepcopy(self.worker_config.strategy_args.strategy_config) @@ -75,15 +81,6 @@ def initialize(self, model_provider): "load_format": vllm_config.get("load_format", "dummy"), # use model update passed value } ) - self.is_lora = self.worker_config.model_args.lora_target is not None - if self.is_lora: - lora_kwargs = { - "enable_lora": True, - "max_loras": 1, - "max_lora_rank": self.worker_config.model_args.lora_rank, - } - vllm_config.update(lora_kwargs) - vllm_config["load_format"] = "auto" # enables vLLM to load the base model for add_lora logger.info(f"vllm_config: {vllm_config}") assert not dist.is_initialized() @@ -133,31 +130,14 @@ def generate(self, batch: DataProto, generation_config) -> torch.Tensor: input_ids = batch.batch["input_ids"] # (bs, prompt_length) attention_mask = batch.batch["attention_mask"] # left-padded attention_mask - vllm_input_args = {} - if "multi_modal_data" in batch.non_tensor_batch: - vllm_input_args["prompts"] = batch.non_tensor_batch["multi_modal_data"] - else: - vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids( - input_ids=input_ids, attention_mask=attention_mask - ) + # **ASSERTION**: Multi-modal data should be empty + assert "multi_modal_data" not in batch.non_tensor_batch, f"multi_modal_data should be empty but found in batch" + + vllm_input_args = { + "prompt_token_ids": gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) + } - lora_requests = None - if self.is_lora: - batch_size = len(input_ids) - lora_int_ids = list(self.model.llm_engine.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id = lora_int_ids[0] - lora_requests = [ - LoRARequest( - lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="dummy_lora_path" - ) - ] * batch_size - vllm_outputs = self.model.generate( - sampling_params=sampling_params, - use_tqdm=False, - lora_request=lora_requests, - **vllm_input_args, - ) + vllm_outputs = self.model.generate(sampling_params=sampling_params, use_tqdm=False, **vllm_input_args) # (bs * num_return_sequences, max_response_len) output_ids = gather_outputs_to_pad_tensor( @@ -173,22 +153,111 @@ def generate(self, batch: DataProto, generation_config) -> torch.Tensor: return output + def process_interrupted_batch(self, request_id: str, request_complete_callback): + # added req but not started running/waiting + # key should exist + assert request_id in self.request_metas, 'key should exist in request_metas' + + output_data = DataProto(meta_info=self.request_metas[request_id]) + output_data.meta_info["finish_status"] = "interrupted" + output_data.meta_info["output_token_ids"] = [] # No output tokens for interrupted batch + logger.info(f"process_interrupted_batch: request_id {output_data.meta_info['request_id']}") + request_complete_callback(data=output_data) + + def handle_vllm_output(self, finished_vllm_outputs: List[RequestOutput], interrupted_rid_set, request_complete_callback): + finished_req_ids = [ ] + to_callback_output = [] + # handle finshed first, then interrupted + for request_output in finished_vllm_outputs: + assert request_output.finished, "should be finished req" + self.unfinished_vllm_outputs.pop(request_output.request_id, None) + + # still in request_metas not aborted, process the request output + if request_output.request_id in self.request_metas: + + finished_req_ids.append(request_output.request_id) + to_callback_output.append(request_output) + + + + if to_callback_output: + logger.info( + f"process_vllm_output: finished request from fetch_output and in request_metas, request_ids 
{finished_req_ids} calling callback") + # this assumes req is in self.request_metas + self.process_vllm_output(vllm_outputs=to_callback_output, + request_complete_callback=request_complete_callback) + + # pop the finished request metas after callback + for request_id in finished_req_ids: + self.request_metas.pop(request_id) + + for req_id in interrupted_rid_set: + + if req_id in self.request_metas: + + if req_id in self.unfinished_vllm_outputs: + logger.info(f'handle_vllm_output: interrupted request_id {req_id} has partial output') + # 🔥 ADD DETAILED PARTIAL DECODE LOGGING HERE 🔥 + partial_request_output = self.unfinished_vllm_outputs[req_id] + for i, output in enumerate(partial_request_output.outputs): + partial_text = self.tokenizer.decode(output.token_ids, skip_special_tokens=True) + logger.info(f"INTERRUPTED_PARTIAL: request_id={req_id}, \n" + f"output_{i}_tokens={len(output.token_ids)}, \n" + f"partial_text='{partial_text}', \n" + f"finished={partial_request_output.finished}\n") + + + self.process_vllm_output(vllm_outputs=[self.unfinished_vllm_outputs[req_id]], + request_complete_callback=request_complete_callback) + # self.unfinished_vllm_outputs.pop(req_id) + + else: + logger.info(f"handle_vllm_output: interrupted request_id {req_id} no partial output yet, just process the from added_batch") + self.process_interrupted_batch(req_id, request_complete_callback) + + self.request_metas.pop(req_id) + + else: + logger.warning(f"handle_vllm_output: interrupted request_id {req_id} not found in added_batch, skipping perhaps already finished or aborted") + + interrupted_rid_set.clear() + + + def process_vllm_output(self, vllm_outputs: List[RequestOutput], request_complete_callback): # 转成response id, request_complete_callback for request_output in vllm_outputs: output_token_ids = [] request_id = request_output.request_id if request_id not in self.request_metas: + logger.warning(f"process_vllm_output: request_id {request_id} not in request_metas, skipping") continue for completion_output in request_output.outputs: output_token_ids.append(completion_output.token_ids) output_data = DataProto(meta_info=self.request_metas[request_id]) output_data.meta_info["output_token_ids"] = output_token_ids + + if request_output.finished: + output_data.meta_info["finish_status"] = "finished" + # not interrupted yet, otherwise will be remove from added batch + # if request_id in self.added_batch: + # logger.info(f"process_vllm_output: finished request_id {request_id}") + + else: + output_data.meta_info["finish_status"] = "interrupted" + + logger.info( + f"VLLM RAW OUTPUT: request_id={request_output.request_id}, " + f"output_token_ids={request_output.outputs[0].token_ids}" + ) + request_complete_callback(data=output_data) def start_server(self, data: DataProto, request_complete_callback): collective.barrier(group_name=self.group_name) self.running = True + self.unfinished_vllm_outputs = {} + interrupted_rid_set = set() while True: while not self.command_queue.empty(): command, batch = self.command_queue.get_nowait() @@ -196,6 +265,14 @@ def start_server(self, data: DataProto, request_complete_callback): input_ids = batch.batch["input_ids"] attention_mask = batch.batch["attention_mask"] request_id = batch.meta_info["request_id"] + + # Debug: Log raw tensor info + logger.info(f"RAW_TENSOR_DEBUG: request_id={request_id}, input_ids_id={id(input_ids)}, input_ids_device={input_ids.device}, input_ids_dtype={input_ids.dtype}") + logger.info(f"RAW_TENSOR_DEBUG: request_id={request_id}, input_ids_shape={input_ids.shape}, 
input_ids_first_10={input_ids[0][:10].tolist()}") + + # Debug: Check if the tensor is a view/shared memory + logger.info(f"RAW_TENSOR_DEBUG: request_id={request_id}, input_ids_is_contiguous={input_ids.is_contiguous()}, input_ids_stride={input_ids.stride()}") + self.request_metas[request_id] = batch.meta_info generation_config = batch.meta_info.get("generation_config") max_new_tokens = batch.meta_info.get("max_new_tokens", generation_config["max_new_tokens"]) @@ -203,39 +280,250 @@ def start_server(self, data: DataProto, request_complete_callback): sampling_params = create_sampling_params_for_vllm( gen_kwargs={**generation_config, "max_new_tokens": max_new_tokens} ) - if "multi_modal_data" in batch.non_tensor_batch: - prompt_token_ids = [ - batch.non_tensor_batch["multi_modal_data"][0] - ["prompt_token_ids"] - ] - multi_modal_data = [ - batch.non_tensor_batch["multi_modal_data"][0] - ["multi_modal_data"] - ] + + # Debug: Check if there's any text in non_tensor_batch + if hasattr(batch, 'non_tensor_batch') and batch.non_tensor_batch: + logger.info(f"RAW_TENSOR_DEBUG: request_id={request_id}, non_tensor_batch_keys={list(batch.non_tensor_batch.keys())}") + + # Check if there's any text we can compare + for key in ['prompt', 'text', 'messages', 'ground_truth']: + if key in batch.non_tensor_batch: + text_data = batch.non_tensor_batch[key] + if hasattr(text_data, '__len__') and len(text_data) > 0: + sample_text = str(text_data[0])[:100] if text_data[0] else "None" + logger.info(f"RAW_TENSOR_DEBUG: request_id={request_id}, {key}_sample='{sample_text}'") + + # **ASSERTION**: Multi-modal data should be empty + assert "multi_modal_data" not in batch.non_tensor_batch, f"request_id={request_id}: multi_modal_data should be empty but found in batch" + multi_modal_data = None + + # Check if this is a continuation request (interrupted request being resumed) + is_continued_request = batch.meta_info.get("is_continued_request", False) + continuation_mode = batch.meta_info.get("continuation_mode", False) + + if continuation_mode and "partial_output_tokens" in batch.meta_info: + logger.info(f"PROCESSING_PATH: request_id={request_id}, using CONTINUATION mode") + + # **ASSERTIONS for continued/interrupted requests** + assert is_continued_request, f"request_id={request_id}: continuation_mode=True but is_continued_request=False" + assert "original_prompt_length" in batch.meta_info, f"request_id={request_id}: continuation_mode requires original_prompt_length in meta_info" + assert "cumulative_partial_output_length" in batch.meta_info, f"request_id={request_id}: continuation_mode requires cumulative_partial_output_length in meta_info" + assert "migration_count" in batch.meta_info, f"request_id={request_id}: continuation_mode requires migration_count in meta_info" + + # Handle continuation mode - concatenate original prompt + partial output + partial_output_tokens = batch.meta_info["partial_output_tokens"] + + # **ATTENTION MASK DEBUG**: Log attention mask details before gather_unpadded_input_ids + logger.info(f"ATTENTION_MASK_DEBUG: request_id={request_id}, input_ids.shape={input_ids.shape}, attention_mask.shape={attention_mask.shape}") + logger.info(f"ATTENTION_MASK_DEBUG: request_id={request_id}, attention_mask.sum()={attention_mask.sum().item()}, attention_mask={attention_mask.tolist()}") + + logger.info(f"Continuation mode for request {request_id}: before gather_unpadded_input_ids: request_id={request_id}, input_ids.shape={input_ids.shape}, input_ids_first_10={input_ids[0][:10].tolist()}") + 
original_prompt_tokens = gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) + logger.info(f"Continuation mode for request {request_id}: after gather_unpadded_input_ids: request_id={request_id}, len(original_prompt_tokens[0])={ len(original_prompt_tokens[0])} original_prompt_tokens_first_10={original_prompt_tokens[0][:10] if original_prompt_tokens and len(original_prompt_tokens[0]) > 0 else 'empty'}") + + # **CRITICAL ASSERTION**: Verify gathered tokens match attention mask + assert len(original_prompt_tokens[0]) == attention_mask.sum().item(), f"request_id={request_id}: gathered tokens length {len(original_prompt_tokens[0])} != attention_mask sum {attention_mask.sum().item()}" + + # **ASSERTION**: Verify the input_ids contains the expected continued input + original_prompt_length = batch.meta_info["original_prompt_length"] + cumulative_partial_output_length = batch.meta_info["cumulative_partial_output_length"] + expected_input_length = original_prompt_length + cumulative_partial_output_length + actual_input_length = len(original_prompt_tokens[0]) + + # **CRITICAL LENGTH VERIFICATION** + logger.info(f"LENGTH_VERIFICATION: request_id={request_id}, original_prompt_length={original_prompt_length}, cumulative_partial_output_length={cumulative_partial_output_length}, expected_input_length={expected_input_length}, actual_input_length={actual_input_length}") + logger.info(f"LENGTH_VERIFICATION: request_id={request_id}, attention_mask.sum()={attention_mask.sum().item()}, input_ids.shape[1]={input_ids.shape[1]}") + + # **ASSERTION**: Verify the attention mask fix worked correctly + assert input_ids.shape[1] == attention_mask.shape[1], f"request_id={request_id}: input_ids length {input_ids.shape[1]} != attention_mask length {attention_mask.shape[1]}" + assert actual_input_length <= input_ids.shape[1], f"request_id={request_id}: gathered length {actual_input_length} > input_ids length {input_ids.shape[1]}" + + # **FIXED ASSERTION**: Now that we track effective lengths, this should match + assert actual_input_length == expected_input_length, f"request_id={request_id}: effective input_ids length mismatch. Expected {expected_input_length} (original_effective={original_prompt_length} + cumulative_partial={cumulative_partial_output_length}), got {actual_input_length}. Note: input_ids.shape[1]={input_ids.shape[1]}, attention_mask.sum()={attention_mask.sum().item()}" + + # **CRITICAL FIX**: Don't concatenate again! 
The input_ids already contains the continued input + # The scheduler has already concatenated: original_prompt + all_previous_partial_outputs + # We just need to use the input_ids as-is, not concatenate partial_output_tokens again + prompt_token_ids = original_prompt_tokens + + logger.info(f"Continuation mode for request {request_id}: " + f"original_prompt_length={original_prompt_length}, " + f"cumulative_partial_output_length={cumulative_partial_output_length}, " + f"current_partial_output_length={len(partial_output_tokens)}, " + f"input_ids_length={actual_input_length}") else: + # Normal mode + logger.info(f"PROCESSING_PATH: request_id={request_id}, using NORMAL mode") + + # **ATTENTION MASK DEBUG**: Log attention mask details for normal mode too + logger.info(f"ATTENTION_MASK_DEBUG: request_id={request_id}, input_ids.shape={input_ids.shape}, attention_mask.shape={attention_mask.shape}") + logger.info(f"ATTENTION_MASK_DEBUG: request_id={request_id}, attention_mask.sum()={attention_mask.sum().item()}, attention_mask={attention_mask.tolist()}") + + logger.info(f"Before gather_unpadded_input_ids: request_id={request_id}, input_ids.shape={input_ids.shape}, input_ids_first_10={input_ids[0][:10].tolist()}") prompt_token_ids = gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) - multi_modal_data = None - lora_requests = None - if self.is_lora: - batch_size = len(prompt_token_ids) - lora_int_ids = list(self.model.llm_engine.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id = lora_int_ids[0] - lora_requests = [ - LoRARequest( - lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="dummy_lora_path" - ) - ] * batch_size - self.model.add_requests( - request_ids=[request_id], - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - multi_modal_data=multi_modal_data, - lora_requests=lora_requests, - ) + logger.info(f"After gather_unpadded_input_ids: request_id={request_id}, len(prompt_token_ids[0])={ len(prompt_token_ids[0])} prompt_token_ids_first_10={prompt_token_ids[0][:10] if prompt_token_ids and len(prompt_token_ids[0]) > 0 else 'empty'}") + + # **ASSERTION**: Verify gathered tokens match attention mask for normal mode + assert len(prompt_token_ids[0]) == attention_mask.sum().item(), f"request_id={request_id}: gathered tokens length {len(prompt_token_ids[0])} != attention_mask sum {attention_mask.sum().item()}" + + # Debug logging - all in one line + decoded_prompt = self.tokenizer.decode(prompt_token_ids[0], skip_special_tokens=True) if prompt_token_ids else "" + + # Debug: Compare original input_ids vs processed prompt_token_ids + if prompt_token_ids and len(prompt_token_ids[0]) > 0: + original_decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) + processed_decoded = self.tokenizer.decode(prompt_token_ids[0], skip_special_tokens=True) + + # Check if they're different + if original_decoded != processed_decoded: + logger.info(f"DECODE_MISMATCH: request_id={request_id}") + logger.info(f"DECODE_MISMATCH: original_decoded='{original_decoded[:200]}'") + logger.info(f"DECODE_MISMATCH: processed_decoded='{processed_decoded[:200]}'") + logger.info(f"DECODE_MISMATCH: original_input_ids={input_ids[0].tolist()}") + logger.info(f"DECODE_MISMATCH: processed_prompt_token_ids={prompt_token_ids[0]}") + else: + logger.info(f"DECODE_MATCH: request_id={request_id}, decoded prompts are identical") + + logger.info(f"ADD_REQUEST: request_id={request_id} | input_ids.shape={input_ids.shape} | attention_mask.shape={attention_mask.shape} | " 
+ f"input_ids={input_ids.tolist()} | attention_mask sum ={sum(attention_mask)} | " + f"processed_prompt_token_ids={prompt_token_ids} | \ndecoded_prompt='{decoded_prompt}'") + + self.model.add_requests(request_ids=[request_id], + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_data=multi_modal_data) + logger.info(f"request {request_id} added") + elif command == GenerateRequestType.ABORT: request_id = batch.meta_info["request_id"] + assert request_id in self.request_metas + logger.info(f"{request_id=} abort command sent to backend engine, remove from request_metas") self.model.abort_request(request_id=request_id) + self.request_metas.pop(request_id) + + elif command == GenerateRequestType.INTERRUPT: + request_id = batch.meta_info.get("request_id", None) + target_leftover_cnt = batch.meta_info.get("target_leftover_cnt", None) + assert (request_id is None) ^ (target_leftover_cnt is None), f"they are exclusive but got {request_id=} {target_leftover_cnt=}" + if request_id: + assert request_id in self.request_metas, f"request_id {request_id} not in request_metas {self.request_metas.keys()}" + logger.info(f"interrupt request command sent to backend engine") + self.model.abort_request(request_id=request_id) + interrupted_rid_set.add(request_id) + + + if target_leftover_cnt: + # Check if we have the v1 engine + if hasattr(self.model.llm_engine, 'engine_core'): + # Use v1 engine's collective_rpc to call abort_to_target_requests_cnt + logger.info(f"Using v1 engine abort_to_target_requests_cnt with target={target_leftover_cnt}") + + # Use collective_rpc to call the method on the engine core + results = self.model.llm_engine.collective_rpc( + method='abort_to_target_requests_cnt', + args=(target_leftover_cnt,) + ) + # collective_rpc returns a list of results of interrupted requet id or None if no request was interrupted + if results is not None: + assert isinstance(results, list), f"result from collective_rpc should be a list, got {type(results)}" + for interrupted_rids in results: + assert interrupted_rids in self.request_metas, f"interrupted_rids {interrupted_rids} not in request_metas {self.request_metas.keys()}" + interrupted_rid_set.update(interrupted_rids) + logger.info(f"V1 engine interrupted {len(interrupted_rid_set)} requests {interrupted_rid_set}" ) + # Note: We can't track which specific requests were interrupted in v1 + logger.info(f"V1 engine processed interruption to target count {target_leftover_cnt}") + + else: + # Fallback for v0 engine - use original implementation + # interrupt requests up to the point of target_leftover_cnt, + # sort request by the overhead of migration and find the easiest to interrupt + + # Get current request counts from vLLM engine stats + stats = self.model.llm_engine._get_stats(scheduler_outputs=None) + + # Count requests in different queues + waiting_count = stats.num_waiting_sys + running_count = stats.num_running_sys + swapped_count = stats.num_swapped_sys + total_count = waiting_count + running_count + swapped_count + + # Calculate how many requests to interrupt + interrupt_count = total_count - target_leftover_cnt + + if interrupt_count <= 0: + logger.info(f"No interruption needed: total_count={total_count}, target_leftover={target_leftover_cnt}") + continue + + logger.info(f"Dynamic load balance: need to interrupt {interrupt_count} requests " + f"(waiting={waiting_count}, running={running_count}, swapped={swapped_count}, " + f"target_leftover={target_leftover_cnt})") + + # Get the scheduler to access request queues + 
scheduler = self.model.llm_engine.scheduler[0] + + # Get all requests from queues + waiting_requests = list(scheduler.waiting) + swapped_requests = list(scheduler.swapped) + running_requests = list(scheduler.running) + + # Merge swapped and running requests and sort by total length (shortest first) + swapped_and_running = [] + + # Add swapped requests + for seq_group in swapped_requests: + request_id = seq_group.request_id + assert request_id in self.request_metas, f"request_id {request_id} not found in request_metas buffer" + + # Calculate total sequence length (prompt + generated tokens) + total_length = 0 + for seq in seq_group.get_seqs(): + total_length += seq.get_len() + + swapped_and_running.append((request_id, total_length, 'swapped', seq_group)) + + # Add running requests + for seq_group in running_requests: + request_id = seq_group.request_id + assert request_id in self.request_metas, f"request_id {request_id} not found in request_metas buffer" + + # Calculate total sequence length (prompt + generated tokens) + total_length = 0 + for seq in seq_group.get_seqs(): + total_length += seq.get_len() + + swapped_and_running.append((request_id, total_length, 'running', seq_group)) + + # Sort by total length (ascending) - interrupt shortest sequences first + swapped_and_running.sort(key=lambda x: x[1]) + + # Concatenate all requests in priority order: waiting -> (swapped+running sorted by length) + all_requests_ordered = [] + all_requests_ordered.extend([(sg.request_id, 'waiting', sg) for sg in waiting_requests]) + all_requests_ordered.extend([(rid, status, sg) for rid, _, status, sg in swapped_and_running]) + + # Select requests to interrupt using while loop + requests_to_interrupt = [] + idx = 0 + + while len(requests_to_interrupt) < interrupt_count and idx < len(all_requests_ordered): + request_id, status, seq_group = all_requests_ordered[idx] + requests_to_interrupt.append(request_id) + + if status == 'running': + total_length = sum(seq.get_len() for seq in seq_group.get_seqs()) + logger.info(f"Selected {status} request {request_id} for interruption (total_length={total_length})") + else: + logger.info(f"Selected {status} request {request_id} for interruption") + + idx += 1 + + # Step 3: Abort the selected requests + for request_id in requests_to_interrupt: + logger.info(f"Interrupting request {request_id}") + self.model.abort_request(request_id=request_id) + interrupted_rid_set.add(request_id) + + elif command == GenerateRequestType.STOP: self.model.abort_request(request_id=list(self.request_metas.keys())) self.request_metas.clear() @@ -246,10 +534,67 @@ def start_server(self, data: DataProto, request_complete_callback): # model execute loop or there will be garbage output at next step. 
self.model.clear_unfinished_requests() self.running = False + self.unfinished_vllm_outputs.clear() return - vllm_outputs: List[RequestOutput] = self.model.fetch_output() - self.process_vllm_output(vllm_outputs=vllm_outputs, request_complete_callback=request_complete_callback) + finished_vllm_outputs, unfinished_vllm_outputs = self.model.fetch_output() + # add or update the buffer of unfinished request output + for request_output in unfinished_vllm_outputs: + + # if request in added_batch/not aborted, update the buffered the partial request output + if request_output.request_id in self.request_metas: + self.unfinished_vllm_outputs[request_output.request_id] = request_output + + # Log finished outputs for debugging + for f in finished_vllm_outputs: + # Collect all prompt information + token_prompt = "" + text_prompt = "" + token_ids = [] + + if hasattr(f, 'prompt_token_ids') and f.prompt_token_ids: + token_prompt = self.tokenizer.decode(f.prompt_token_ids, skip_special_tokens=True) + token_ids = f.prompt_token_ids + + if hasattr(f, 'prompt') and f.prompt: + text_prompt = f.prompt + + # Collect all outputs + outputs_info = [] + for i, output in enumerate(f.outputs): + output_text = self.tokenizer.decode(output.token_ids, skip_special_tokens=True) + outputs_info.append(f"Output_{i}[tokens={len(output.token_ids)}]: '{output_text}'") + + # Log everything in a single line + logger.info(f"FINISHED_OUTPUT: request_id={f.request_id}, finished={f.finished} | \n" + f"TOKEN_PROMPT: '{token_prompt}' \n| TOKEN_IDS: {token_ids} | \n" + f"TEXT_PROMPT: '{text_prompt}' \n| OUTPUTS: {' | '.join(outputs_info)}") + + self.handle_vllm_output(finished_vllm_outputs, interrupted_rid_set, request_complete_callback) + + + + + # + # for request_output in finished_vllm_outputs: + # # still in added_batch not aborted, process the request output + # if request_output.request_id in self.added_batch: + # logger.info( + # f"process_vllm_output: finished request from fetch_output and in added_batch, request_id {request_output.request_id}") + # self.process_vllm_output(vllm_outputs=request_output, + # request_complete_callback=request_complete_callback) + # self.unfinished_vllm_outputs.pop(request_output.request_id, None) + # self.added_batch.pop(request_output.request_id, None) + # + # # add or update the buffer of unfinished request output + # for request_output in unfinished_vllm_outputs: + # + # # if request in added_batch/not aborted, update the buffered the partial request output + # if request_output.request_id in self.added_batch: + # self.unfinished_vllm_outputs[request_output.request_id] = request_output + # + # self.handle_interrupted_requests(request_complete_callback) + def add_request(self, command, data: DataProto): self.command_queue.put((command, data)) @@ -265,11 +610,10 @@ async def async_generate(self, batch: DataProto, generation_config: Dict) -> tor attention_mask = batch.batch["attention_mask"] # left-padded attention_mask assert input_ids.size(0) == 1, f"async_generate: batch['input_ids'] must have exactly one batch dimension" - prompt_token_ids = gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) - # TODO meaningful request id? 
# async_generate如何实现abort_request request_id = random_uuid() + prompt_token_ids = gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) result_generator = self.model.generate( prompt=TokensPrompt(prompt_token_ids=prompt_token_ids[0]), sampling_params=sampling_params, @@ -302,27 +646,33 @@ def offload_states(self, include=None, non_blocking=False): torch.cuda.empty_cache() # 参数同步相关接口 - def setup_collective_group(self, model_update_name, comm_plan, backend="nccl"): + def setup_collective_group(self, comm_plan, backend="nccl"): self.model.setup_collective_group(comm_plan=comm_plan, backend=backend, rank_in_cluster=self.worker.rank) - def broadcast_parameter(self, model_update_name, src_pp_rank, dtype, shape, parameter_name, is_lora=False): - self.model.broadcast_parameter(src_pp_rank, dtype, shape, parameter_name, is_lora) + def broadcast_parameter(self, src_pp_rank, dtype, shape, parameter_name): + self.model.broadcast_parameter(src_pp_rank, dtype, shape, parameter_name) - def broadcast_bucket(self, model_update_name, src_pp_rank, meta_infos, bucket_size): + def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): self.model.broadcast_bucket(src_pp_rank, meta_infos, bucket_size) - def update_parameter(self, model_update_name, parameter_name, weight, ranks_in_worker, is_lora=False): - self.model.update_parameter(parameter_name, weight, ranks_in_worker, is_lora) + def update_parameter(self, parameter_name, weight, ranks_in_worker): + self.model.update_parameter(parameter_name, weight, ranks_in_worker) - def update_parameter_in_bucket(self, model_update_name, meta_infos, buffer, ranks_in_worker): + def update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): self.model.update_parameter_in_bucket(meta_infos, buffer, ranks_in_worker) - def add_lora(self, peft_config): - self.model.add_lora(peft_config) - def gather_unpadded_input_ids(input_ids: torch.Tensor, attention_mask: torch.Tensor): + # Debug: Log input details + logger.info(f"GATHER_DEBUG: input_ids_shape={input_ids.shape}, attention_mask_shape={attention_mask.shape}") + logger.info(f"GATHER_DEBUG: input_ids={input_ids.tolist()}") + logger.info(f"GATHER_DEBUG: attention_mask sum ={sum(attention_mask)}") + gathered_input_ids = [ids[mask.bool()].tolist() for ids, mask in zip(input_ids, attention_mask)] + + # Debug: Log output details + logger.info(f"GATHER_DEBUG: gathered_input_ids={gathered_input_ids}") + return gathered_input_ids diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index cf8ad43c..44a8fa08 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -60,6 +60,15 @@ def initialize(self, pipeline_config): # current process is used as engine client when using vllm v1 engine, and # there is no chance to init cuda context. 
         torch.cuda.init()
+        # Training scheduler (use_additional_prompts = True or default)
+        import pydevd_pycharm
+
+        debug_port = 12345
+        scheduler_type = "TRAINING"
+        self.logger.info(f"Connecting PyCharm debugger on port {debug_port}")
+        if os.getenv("PYCHARM", "0") == "1":
+            pydevd_pycharm.settrace('localhost', port=debug_port, stdoutToServer=True, stderrToServer=True, suspend=False)
+            self.logger.info(f"PyCharm debugger attached to {scheduler_type} scheduler on port {debug_port}")

     @register(dispatch_mode=Dispatch.DP_MP_DISPATCH_FIRST)
     def train_step(self, data: DataProto):
@@ -356,7 +365,10 @@ def add_request(self, command, data: DataProto):
         if self.thread_server is not None:
             if not self.thread_server.is_alive():
                 raise Exception("thread server has stopped unexpectedly. check stderr for more info.")
-            output = DataProto(meta_info={"request_counts": len(self.response_call_back_fns)})
+            output = DataProto(meta_info={"request_counts": len(self.response_call_back_fns),
+                                          "registered_requests_callbacks": list(self.response_call_back_fns.keys())
+                                          })
+
             return output
         elif command == GenerateRequestType.ADD:
             assert "response_callback_fn" in data.meta_info, "response_callback_fn is not in data.meta_info"
@@ -374,6 +386,8 @@ def add_request(self, command, data: DataProto):
             ] + self.tokenizer.additional_special_tokens_ids
             generation_config["pad_token_id"] = self.tokenizer.pad_token_id
             data.meta_info["generation_config"] = generation_config
+            self.logger.info(
+                f"worker add_request: {data.meta_info['request_id']} callback_fn: {data.meta_info.get('response_callback_fn', None)}")
             self.response_call_back_fns[data.meta_info["request_id"]] = data.meta_info.pop("response_callback_fn")
             self.strategy.add_request(command=command, data=data)
             return DataProto(meta_info={"request_counts": len(self.response_call_back_fns)})
@@ -381,8 +395,19 @@ def request_complete(self, data: DataProto):
         data.meta_info["eos_token_id"] = self.tokenizer.eos_token_id
         data.meta_info["pad_token_id"] = self.tokenizer.pad_token_id
+        fn = self.response_call_back_fns.get(data.meta_info["request_id"], None)
+        self.logger.info(
+            f"worker request_complete: pop from response_call_back_fns callback_fn for {data.meta_info['request_id']}: {fn}")
+        # TODO(tao): perhaps do not invoke the callback until the request actually finishes.
+ response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"]) - self.response_callback_refs.append(response_call_back_fn(data)) + call = response_call_back_fn(data) + # shall not block too long here + read_res, _ = ray.wait([call], timeout=5.0) + if len(read_res) == 0: + self.logger.warning(f"response callback on report_response for {data.meta_info['request_id']} is not finished in 5s, " + ) + self.response_callback_refs.append(call) class CriticWorker(Worker): diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index 4466664c..db1646ce 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -20,31 +20,21 @@ from roll.models.model_providers import default_tokenizer_provider from roll.pipeline.base_pipeline import BasePipeline from roll.pipeline.rlvr.rlvr_config import RLVRConfig -from roll.pipeline.rlvr.utils import dump_rollout_to_specific_path from roll.utils.functionals import ( - RunningMoments, - agg_loss, compute_advantage, - compute_token_reward, - get_sample_level_mask, reduce_metrics, + RunningMoments, + get_sample_level_mask, reward_postprocess, + compute_token_reward, + agg_loss, ) from roll.utils.kl_controller import get_kl_controller from roll.utils.logging import get_logger from roll.utils.metrics.metrics_manager import MetricsManager - logger = get_logger() - - -def is_lora_training(pipeline_config: RLVRConfig) -> bool: - if pipeline_config.actor_train.model_args.lora_target is None: - return False - assert pipeline_config.actor_train.strategy_args.strategy_name == "deepspeed_train", ( - "LoRA only supports deepspeed_train" - ) - return True +test_sched = True def preprocess_dataset(dataset, prompt_len, encode_function, num_proc): @@ -123,7 +113,6 @@ class RLVRPipeline(BasePipeline): def __init__(self, pipeline_config: RLVRConfig): super().__init__(pipeline_config) self.pipeline_config = pipeline_config - self.is_lora = is_lora_training(self.pipeline_config) self.tokenizer = default_tokenizer_provider(model_args=self.pipeline_config.actor_train.model_args) @@ -133,7 +122,8 @@ def __init__(self, pipeline_config: RLVRConfig): print(f'load_dataset_paths: {chr(10)} {chr(10).join(dataset_paths)}') dataset = datasets.load_dataset('json', data_files=dataset_paths)['train'] - + print(f"dataset loaded: {dataset[:2]}") + # exit(0) self.val_dataset = None if self.pipeline_config.validation: val_dataset_paths = self.pipeline_config.validation.data_args.file_name @@ -201,21 +191,19 @@ def __init__(self, pipeline_config: RLVRConfig): worker_cls=self.pipeline_config.actor_train.worker_cls, resource_manager=self.resource_manager, worker_config=self.pipeline_config.actor_train, - ) + ) if not test_sched else None self.actor_infer: Any = Cluster( name=self.pipeline_config.actor_infer.name, worker_cls=self.pipeline_config.actor_infer.worker_cls, resource_manager=self.resource_manager, worker_config=self.pipeline_config.actor_infer, ) - # use unwrapped model as reference for lora training - if not self.is_lora: - self.reference: Any = Cluster( - name=self.pipeline_config.reference.name, - worker_cls=self.pipeline_config.reference.worker_cls, - resource_manager=self.resource_manager, - worker_config=self.pipeline_config.reference, - ) + self.reference: Any = Cluster( + name=self.pipeline_config.reference.name, + worker_cls=self.pipeline_config.reference.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.pipeline_config.reference, + ) if not test_sched else None if 
self.pipeline_config.adv_estimator == "gae": self.critic: Any = Cluster( name=self.pipeline_config.critic.name, @@ -245,6 +233,7 @@ def __init__(self, pipeline_config: RLVRConfig): domain_batch_size = int(domain_ratios[domain] * self.pipeline_config.rollout_batch_size) accumulated += domain_batch_size generate_scheduler = DynamicSamplingScheduler.options( + name=f"DynamicSamplingScheduler-{domain}", scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False, @@ -265,14 +254,16 @@ def __init__(self, pipeline_config: RLVRConfig): ) self.generate_schedulers[domain] = generate_scheduler self.domain_batch_size[domain] = domain_batch_size - - assert domain_batch_size < len(self.domain_datasets[domain]), (f"domain_batch_size {domain_batch_size} must be " - f"less than the number of domain datasets {len(self.domain_datasets[domain])}") + # skip for using fixed dataset item + # assert domain_batch_size < len(self.domain_datasets[domain]), (f"domain_batch_size {domain_batch_size} must be " + # f"less than the number of domain datasets {len(self.domain_datasets[domain])}") if self.val_dataset: val_pipeline_config = copy.deepcopy(self.pipeline_config) val_pipeline_config.use_additional_prompts = False self.val_generate_scheduler = DynamicSamplingScheduler.options( + name="DynamicSamplingScheduler-validation", + scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False, @@ -296,20 +287,23 @@ def __init__(self, pipeline_config: RLVRConfig): refs.extend(self.actor_infer.initialize(pipeline_config=self.pipeline_config, blocking=False)) ray.get(refs) - if not self.is_lora: + if not test_sched: refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True)) + refs = [] for key, cluster in self.rewards.items(): refs.extend(cluster.initialize(pipeline_config=self.pipeline_config, blocking=False)) ray.get(refs) refs: List[ray.ObjectRef] = [] - refs.extend(self.actor_train.initialize(pipeline_config=self.pipeline_config, blocking=False)) + if not test_sched: + refs.extend(self.actor_train.initialize(pipeline_config=self.pipeline_config, blocking=False)) if self.pipeline_config.adv_estimator == "gae": refs.extend(self.critic.initialize(pipeline_config=self.pipeline_config, blocking=False)) ray.get(refs) - self.set_model_update_pair( + if not test_sched: + self.set_model_update_pair( src_cluster=self.actor_train, tgt_cluster=self.actor_infer, frequency=self.pipeline_config.actor_train.model_update_frequency, @@ -318,7 +312,8 @@ def __init__(self, pipeline_config: RLVRConfig): if self.pipeline_config.adv_estimator == "gae": self.set_checkpoint_clusters(self.actor_train, self.critic) else: - self.set_checkpoint_clusters(self.actor_train) + if not test_sched: + self.set_checkpoint_clusters(self.actor_train) self.running = {} for domain in self.rewards.keys(): @@ -348,14 +343,18 @@ def run(self): # 先model update,resume时不需要保存infer cluster的状态 if self.pipeline_config.adv_estimator == "gae": self.critic.offload_states(blocking=True) - self.actor_train.offload_states(blocking=True) - with Timer(name="step_model_update", logger=None) as step_model_update_timer: - model_update_metrics: Dict = self.model_update(global_step) - metrics_mgr.add_metrics(model_update_metrics) - metrics_mgr.add_metric("time/step_model_update", step_model_update_timer.last) + if not test_sched: + self.actor_train.offload_states(blocking=True) - if self.val_dataset and global_step % self.pipeline_config.eval_steps == 
0: + with Timer(name="step_model_update", logger=None) as step_model_update_timer: + model_update_metrics: Dict = self.model_update(global_step) + metrics_mgr.add_metrics(model_update_metrics) + metrics_mgr.add_metric("time/step_model_update", step_model_update_timer.last) + + + # if self.val_dataset and global_step % self.pipeline_config.eval_steps == 0: + if False: with Timer(name="val_step", logger=None) as val_step_timer: val_metrics = self.val() metrics_mgr.add_metrics(val_metrics) @@ -370,6 +369,9 @@ def run(self): ) as step_generate_timer: domain_batches = {} batch.meta_info["generation_config"] = self.actor_infer.worker_config.generating_args.to_dict() + + # Tao: To keep constant seed for determinsitic sampling + batch.meta_info["generation_config"]['seed'] = 123 self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info)) for reward_cluster in self.rewards.values(): reward_cluster.load_states() @@ -385,7 +387,6 @@ def run(self): ) domain_batches[domain] = domain_batch generate_output = DataProto.concat([domain_batch for domain_batch in domain_batches.values()]) - dump_rollout_to_specific_path(self.pipeline_config.rollout_dump_dir, global_step, generate_output, self.tokenizer) generate_output.meta_info.pop("is_offload_states", None) for reward_cluster in self.rewards.values(): @@ -395,22 +396,22 @@ def run(self): metrics_mgr.add_metric("time/step_generate", step_generate_timer.last) batch = generate_output + + + + if test_sched: + logger.info("test_sched is True, skip calculation of reference log_probs and old log_probs and training, exiting...") + + return with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: - if self.is_lora: - batch.meta_info["disable_adapter"] = True - batch.meta_info["is_offload_states"] = False - ref_log_probs = self.actor_train.compute_log_probs(batch, blocking=True) - else: - ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) + ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") batch = batch.union(ref_log_probs) metrics_mgr.add_metric("time/ref_log_probs_values", cal_ref_log_probs_timer.last) with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: - if self.is_lora: - batch.meta_info["disable_adapter"] = False batch.meta_info["is_offload_states"] = False if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) diff --git a/roll/third_party/vllm/__init__.py b/roll/third_party/vllm/__init__.py index b6b63dae..f1c68a4e 100644 --- a/roll/third_party/vllm/__init__.py +++ b/roll/third_party/vllm/__init__.py @@ -6,7 +6,7 @@ if "0.7.3" in vllm.__version__: from roll.third_party.vllm.vllm_0_7_3.llm import Llm073 LLM = Llm073 -elif "0.8.4" in vllm.__version__: +elif "0.8.4" in vllm.__version__ or "0.8.5.dev" in vllm.__version__ : from roll.third_party.vllm.vllm_0_8_4.llm import Llm084 from roll.third_party.vllm.vllm_0_8_4.v1.async_llm import AsyncLLM084 LLM = Llm084 diff --git a/roll/third_party/vllm/vllm_0_8_4/llm.py b/roll/third_party/vllm/vllm_0_8_4/llm.py index 36077a6e..ffb97c8b 100644 --- a/roll/third_party/vllm/vllm_0_8_4/llm.py +++ b/roll/third_party/vllm/vllm_0_8_4/llm.py @@ -1,20 +1,23 @@ import os +from typing import Iterable, Optional, List, Dict, Any, Union import queue -import time -from typing import Any, Dict, Iterable, List, 
Optional, Union +import time +import torch import cloudpickle -import torch -from vllm import LLM, EngineArgs, SamplingParams, envs -from vllm.config import CompilationConfig -from vllm.engine.arg_utils import HfOverrides, PoolerConfig, TaskOption -from vllm.lora.request import LoRARequest + +from vllm import LLM, SamplingParams, EngineArgs, envs from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter +from vllm.config import CompilationConfig +from vllm.engine.arg_utils import (HfOverrides, PoolerConfig, TaskOption) from roll.third_party.vllm.vllm_0_8_4.llm_engine import LLMEngine084 from roll.utils.send_recv_utils import SendBucketManager +from roll.utils.logging import get_logger + +logger = get_logger() class Llm084(LLM): @@ -129,6 +132,8 @@ def offload_states(self, level=1): def fetch_output(self): output_list = [] + unfinished_output_list = [] + # simulating non blocking semantic when using v1 engine if envs.VLLM_USE_V1: try: @@ -141,7 +146,10 @@ def fetch_output(self): for request_output in request_outputs: if request_output.finished: output_list.append(request_output) - return output_list + else: + unfinished_output_list.append(request_output) + + return output_list, unfinished_output_list def get_num_waiting(self): stats = self.llm_engine._get_stats(scheduler_outputs=None) @@ -153,7 +161,6 @@ def add_requests( request_ids: List[int] | None, sampling_params: SamplingParams, multi_modal_data: List[int] | None, - lora_requests: List[LoRARequest] | None, ): assert len(prompt_token_ids) == len(request_ids) if multi_modal_data: @@ -161,17 +168,16 @@ def add_requests( for i, (token_ids, request_id)in enumerate(zip(prompt_token_ids, request_ids)): if request_id is None: request_id = next(self.request_counter) - lora_request = lora_requests[i] if lora_requests is not None else None if multi_modal_data: # in v1, input_preprocessor is in engine.processor processor = getattr(self.llm_engine, "processor", None) input_preprocessor = processor.input_preprocessor if processor else self.llm_engine.input_preprocessor preprocessed_inputs = input_preprocessor.preprocess( prompt={"prompt_token_ids": token_ids, "multi_modal_data": multi_modal_data[i]}, - lora_request=lora_request, + lora_request=None, prompt_adapter_request=None, ) - # in v1, engine does not use a input_processor + # in v1, engine does not use a input_processor processed_inputs = ( self.llm_engine.input_processor(preprocessed_inputs) if hasattr(self.llm_engine, "input_processor") @@ -182,16 +188,21 @@ def add_requests( "type": "token", "prompt_token_ids": token_ids } - self.llm_engine._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=sampling_params, - arrival_time=time.time(), - lora_request=lora_request, - prompt_adapter_request=None, - ) + + logger.info(f"llm engine add request {request_id} with processed_inputs length {len(processed_inputs['prompt_token_ids'])} sampling_params {sampling_params} token ids {processed_inputs['prompt_token_ids']}") + # assert len(processed_inputs["prompt_token_ids"]) < 200, "my test should not exceed 200 tokens, please check the input preparetion logic perhaps of interrupted requests" + self.llm_engine._add_processed_request(request_id=request_id, + processed_inputs=processed_inputs, + params=sampling_params, + arrival_time=time.time(), + lora_request=None, + prompt_adapter_request=None + ) def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + + logger.info(f"Aborting requests for output_processor: {request_id}") + 
self.llm_engine.abort_request(request_id) def clear_unfinished_requests(self): @@ -226,6 +237,3 @@ def update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): # Newer version of vllm support efficient serilization of torch.Tensor. buffer = buffer.cpu().tolist() self.collective_rpc(method="update_parameter_in_bucket", args=(meta_infos, buffer, ranks_in_worker)) - - def add_lora(self, *args, **kwargs): - self.collective_rpc(method="add_lora", args=args, kwargs=kwargs) diff --git a/roll/third_party/vllm/vllm_0_8_4/v1/llm_engine.py b/roll/third_party/vllm/vllm_0_8_4/v1/llm_engine.py index a9a3fc26..8280a204 100644 --- a/roll/third_party/vllm/vllm_0_8_4/v1/llm_engine.py +++ b/roll/third_party/vllm/vllm_0_8_4/v1/llm_engine.py @@ -197,8 +197,11 @@ def _add_processed_request( priority: int = 0, ) -> None: if isinstance(params, SamplingParams): - params.output_kind = RequestOutputKind.FINAL_ONLY + # (tao) use cumulative output for roll for collecting partial output + # params.output_kind = RequestOutputKind.FINAL_ONLY + params.output_kind = RequestOutputKind.CUMULATIVE + request = self.processor.custom_process_inputs(request_id, processed_inputs, params, arrival_time, lora_request, trace_headers, @@ -211,9 +214,11 @@ def _add_processed_request( # Make a new RequestState and queue. self.output_processor.add_request(request, None, 0) # Add the request to EngineCore. + logger.info(f"added request to vllm engine {request=}") self.engine_core.add_request(request) return + assert False, "Tao assume not supported for now" # Fan out child requests (for n>1). parent_req = ParentRequest(request_id, params) for idx in range(n): diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index a7fad236..b08202d5 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -6,13 +6,17 @@ import torch import torch.nn.functional as F from tensordict import TensorDict +from torch import Tensor from roll.pipeline.rlvr.rlvr_config import RLVRConfig from roll.utils.kl_controller import AdaptiveKLController from roll.utils.logging import get_logger +import logging + logger = get_logger() +logger.setLevel(logging.DEBUG) def tensor_to_cpu_visitor(obj, path): @@ -728,6 +732,7 @@ class GenerateRequestType(enum.Enum): ABORT = enum.auto() STOP = enum.auto() ALIVE_CHECK = enum.auto() + INTERRUPT = enum.auto() def postprocess_generate( @@ -735,106 +740,115 @@ def postprocess_generate( output: torch.Tensor, num_return_sequences, sequence_length, + canonical_prompt_length: int, eos_token_id, pad_token_id, fill_eos_token=False, ) -> "DataProto": from roll.distributed.scheduler.protocol import DataProto + output_ids = output + output_ids = pad_to_length(output_ids, sequence_length, pad_token_id) + if fill_eos_token: # yali: 如果output最后一个token不是pad_token_id,则替换成eos_token_id, # TODO: 需要消融这个变化的影响 - last_token_index = output.size(1) - 1 - need_replace_mask = output[:, last_token_index] != pad_token_id - output[need_replace_mask, last_token_index] = eos_token_id - - input_ids = prompts.batch["input_ids"] # (bs, prompt_length) - attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask - prompt_id = prompts.batch.get("prompt_id", None) - - # input_batch_size * num_return_sequences - output_batch_size = output.size(0) - input_batch_size = input_ids.size(0) - prompt_length = input_ids.size(1) - - output = pad_to_length(output, sequence_length, pad_token_id) - - assert output.shape[1] == sequence_length, f"output shape {output.shape} != {sequence_length}" - - prompt = output[:, 
:prompt_length].clone() # (bs, prompt_length) - response = output[:, prompt_length:].clone() # (bs, response_length) - - attention_mask = ( - attention_mask.unsqueeze(1).repeat(1, num_return_sequences, 1).view(output_batch_size, prompt_length) - ) - response_mask = get_pad_mask(response_id=response, pad_token=pad_token_id, dtype=attention_mask.dtype) - attention_mask = torch.cat((attention_mask, response_mask), dim=-1) - - position_ids = prompts.batch["position_ids"] - # if is_num_return_sequences_expand=True, num_return_sequences here equals 1 - if position_ids.dim() == 3: # qwen2vl mrope, maybe can support in other ways - position_ids = ( - position_ids.unsqueeze(1) - .repeat(1, num_return_sequences, 1, 1) - .view(output_batch_size, *position_ids.shape[-2:]) - ) - delta_position_id = torch.arange(1, (sequence_length - prompt_length) + 1, device=position_ids.device) - delta_position_id = delta_position_id.view(1, 1, -1).expand(output_batch_size, 3, -1) - response_position_ids = position_ids[..., -1:] + delta_position_id - # left padding for prompt and right padding for response, to be converted - # to right padding which is consistent with output - output_position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - - assert attention_mask.any(dim=1).all(), f"has all 0 attention_mask, {attention_mask} {input_ids}" - first_one = attention_mask.float().argmax(dim=1) - new_response_mask = torch.zeros_like(attention_mask) # response mask for cat input_ids - for i in range(output_batch_size): - shift = first_one[i].item() - if shift > 0: - output[i, :-shift] = output[i, shift:].clone() + last_token_index = output_ids.size(1) - 1 + need_replace_mask = output_ids[:, last_token_index] != pad_token_id + output_ids[need_replace_mask, last_token_index] = eos_token_id + + # Extract batch info from the input prompts + # Note: `prompts` can be a single DataProto for a normal batch, or a list of + # DataProto for a batch containing interrupted requests. We handle this later. + is_continued_request = prompts.meta_info.get("is_continued_request", False) + batch_size = output_ids.shape[0] + device = output_ids.device + + # --- Determine Prompt Length and Pad to Canonical Shape --- + # This is the most complex part. For normal requests, all prompts have the + # same canonical length. For interrupted requests, they have varying lengths + # and must be padded to the canonical length before being combined. + + # This logic assumes a heterogeneous batch has already been handled upstream + # and that for this function call, all prompts either are or will be padded + # to the canonical_prompt_length. + + if is_continued_request: + # For an interrupted request, the 'prompts' object contains the original, + # unpadded prompt. We must pad or truncate it to the canonical length. + prompt_tensor = prompts.batch["input_ids"] + original_prompt_length = prompt_tensor.shape[1] + original_attention_mask = prompts.batch.get("attention_mask", prompt_tensor.ne(pad_token_id)) + + if original_prompt_length > canonical_prompt_length: + # Prompt is longer than canonical, truncate from the left. 
+ prompt = prompt_tensor[:, -canonical_prompt_length:] + attention_mask = original_attention_mask[:, -canonical_prompt_length:] else: - output[i, :] = output[i, :].clone() - valid_length = attention_mask[i].sum().int().item() - response_length = response_mask[i].sum().int().item() - attention_mask[i][:valid_length] = 1 - attention_mask[i][valid_length:] = 0 - new_response_mask[i][valid_length - response_length : valid_length] = 1 - if position_ids.dim() == 3 and shift > 0: - # shift as output to convert to right padding - # NOTE: left shift without clear right might lead to unclean values - # in right part, which especially is the case when using long prompt - # length and short response length. This usually makes no effect if - # mask is right, while it might make trouble to for multi-modal model - # like Qwen2-vl, since extra image_token would be left which might - # cause error: Image features and image tokens do not match - output_position_ids[i, ..., :-shift] = output_position_ids[i, ..., shift:].clone() - # only clean in VLM(qwen2-vl) to make no effect on LLM - if prompt_length > response_length: - output[i, -shift:] = pad_token_id - - prompt_mask = (attention_mask == 1) & (new_response_mask == 0) - if position_ids.dim() == 3: - position_ids = output_position_ids - else: # normal position_ids - position_ids = torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0, max=None) - batch = TensorDict( - { + # Prompt is shorter or equal, pad on the left to match canonical length. + pad_left = canonical_prompt_length - original_prompt_length + prompt = F.pad(prompt_tensor, (pad_left, 0), value=pad_token_id) + attention_mask = F.pad(original_attention_mask, (pad_left, 0), value=0) + else: + # For a normal request, the prompt is already at the canonical length. + prompt = prompts.batch["input_ids"] + attention_mask = prompts.batch["attention_mask"] + assert prompt.shape[1] == canonical_prompt_length, \ + f"Normal prompt length {prompt.shape[1]} must equal canonical {canonical_prompt_length}" + + # --- Final Shape Assertions (Pre-Computation) --- + assert prompt.shape[1] == canonical_prompt_length + assert attention_mask.shape[1] == canonical_prompt_length + assert output_ids.shape[1] == sequence_length + + # --- Construct Final Tensors --- + response = output_ids[:, canonical_prompt_length:] + canonical_response_length = sequence_length - canonical_prompt_length + + # --- Sanity Checks --- + assert prompt.shape[0] == response.shape[0], "Batch size mismatch between prompt and response" + assert response.shape[1] == canonical_response_length, "Response length is incorrect" + assert prompt.shape[1] + response.shape[1] == sequence_length, "Prompt and response lengths do not sum to sequence length" + + # --- Create Final Masks and Tensors --- + final_input_ids = torch.cat((prompt, response), dim=1) + + response_mask = response.ne(pad_token_id).to(attention_mask.dtype) + final_attention_mask = torch.cat((attention_mask, response_mask), dim=-1) + + # For continued requests, the attention mask might have a gap of zeros between + # the end of the prompt and the start of the new response. This fills that gap. 
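To make the comment above concrete, here is a toy illustration, with made-up mask values, of the gap-filling loop that follows: the attended positions are compacted to the front so that exactly valid_length positions remain 1.

import torch

mask = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1, 1, 0])  # pad | prompt | gap | new response | pad
valid_length = mask.sum().int().item()                # 5 attended positions in total
filled = torch.zeros_like(mask)
filled[:valid_length] = 1
assert filled.tolist() == [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]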
+ if is_continued_request: + for i in range(final_attention_mask.shape[0]): + valid_length = final_attention_mask[i].sum().int().item() + final_attention_mask[i][:valid_length] = 1 + final_attention_mask[i][valid_length:] = 0 + + position_ids = torch.arange(0, sequence_length, dtype=torch.long, device=device).unsqueeze(0) + prompt_mask_full = torch.arange(0, sequence_length, device=device)[None, :] < canonical_prompt_length + response_mask_full = ~prompt_mask_full + + # --- Final Shape Assertions (Post-Computation) --- + assert final_input_ids.shape == (batch_size, sequence_length) + assert final_attention_mask.shape == (batch_size, sequence_length) + assert position_ids.expand(batch_size, -1).shape == (batch_size, sequence_length) + assert prompt_mask_full.expand(batch_size, -1).shape == (batch_size, sequence_length) + assert response_mask_full.expand(batch_size, -1).shape == (batch_size, sequence_length) + + batch_dict = { "prompts": prompt, "responses": response, - "input_ids": output, # right pad - "attention_mask": attention_mask, # right pad - "position_ids": position_ids, - "prompt_mask": prompt_mask, - "response_mask": new_response_mask, # right pad, response tokens - }, - batch_size=output_batch_size, + "input_ids": final_input_ids, + "attention_mask": final_attention_mask, + "position_ids": position_ids.expand(batch_size, -1), + "prompt_mask": prompt_mask_full.expand(batch_size, -1), + "response_mask": response_mask_full.expand(batch_size, -1), + } + + return DataProto( + meta_info=prompts.meta_info, + batch=TensorDict(batch_dict, batch_size=[batch_size]), ) - if prompt_id is not None: - prompt_id = ( - prompt_id.squeeze().unsqueeze(1).repeat(1, num_return_sequences).view(output_batch_size, -1).squeeze(-1) - ) - batch["prompt_id"] = prompt_id - return DataProto(batch=batch) def get_dist_info_from_comm_plan(comm_plan, rank_in_cluster, rank_in_worker): diff --git a/tests/pipeline/rlvr_megatron_config_2A100.yaml b/tests/pipeline/rlvr_megatron_config_2A100.yaml new file mode 100644 index 00000000..fd7488a8 --- /dev/null +++ b/tests/pipeline/rlvr_megatron_config_2A100.yaml @@ -0,0 +1,237 @@ +defaults: + - ../../examples/config/deepspeed_zero@_here_ + - ../../examples/config/deepspeed_zero2@_here_ + - ../../examples/config/deepspeed_zero3@_here_ + - ../../examples/config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "rlvr_megatron_test" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: 'false' + PYCHARM: '0' + + +checkpoint_config: + type: file_system + output_dir: ./output/checkpoint + +#track_with: wandb +#tracker_kwargs: +# api_key: your_api_key +# project: roll-rlvr +# log_dir: debug +# tags: +# - roll +# - rlvr +# - debug + + +num_gpus_per_node: 2 +# max_steps is used for setting the learning rate etc better to not change it +max_steps: 500 + +# if early_stop_steps > 0, max_steps will be overridden by early_stop_steps for actual execution +early_stop_steps: 2 + +save_steps: 1000 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +# rollout_batch_size: 1 +rollout_batch_size: 8 +prompt_length: 1024 +response_length: 512 + +num_return_sequences_in_group: 1 +ppo_epochs: 1 +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +whiten_advantages: true +init_kl_coef: 0.1 +adv_estimator: "reinforce" + + +is_use_additional_prompts: false +reward_filter_mean_threshold: 0.3 +max_running_requests: 256 +is_num_return_sequences_expand: false +max_additional_running_prompts: 16 +generate_redundancy_num: 0 + +# +#pretrain: Qwen/Qwen2.5-0.5B-Instruct +#reward_pretrain: Qwen/Qwen2.5-0.5B-Instruct + +# pretrain: Qwen/Qwen2.5-7B-Instruct +# reward_pretrain: Qwen/Qwen2.5-7B-Instruct + +pretrain: Mxode/NanoLM-0.3B-Instruct-v1.1 +reward_pretrain: Mxode/NanoLM-0.3B-Instruct-v1.1 + +validation: + data_args: + template: qwen2_5 + file_name: + - data/math_benchmarks.jsonl + + generating_args: + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.0 # for deterministic generation + num_return_sequences: 1 + eval_steps: 10 + +actor_train: + model_args: + flash_attn: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + # this two params decides the actual max_steps for training + # keep small to make max_steps work + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + warmup_steps: 50 + num_train_epochs: 50 + data_args: + template: qwen2_5 + file_name: + # - /data/oss_bucket_0/rl_examples/data/code_KodCode_data_hard.jsonl +# - /data/oss_bucket_0/rl_examples/data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/test_interrupt.jsonl # write a report to test interrupt +# - data/math_deepmath_deal.jsonl + prompt: instruction + interleave_probs: "1.0" + max_samples: 6400 + preprocessing_num_workers: 16 + domain_interleave_probs: + math_rule: 1.0 + # code_sandbox: 0.4 +# llm_judge: 0.1 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + context_parallel_size: 1 + overlap_grad_reduce: false + use_distributed_optimizer: true + device_mapping: list(range(0,2)) + infer_batch_size: 4 + +actor_infer: + model_args: + flash_attn: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0 # for deterministic generation + # temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + data_args: + template: qwen2_5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.6 + block_size: 16 + max_model_len: 6000 + max_num_seqs: 5 + load_format: "auto" # by default, it is "dummy" that generate nonsense text + device_mapping: list(range(0,2)) + infer_batch_size: 16 + world_size: 2 + 
num_gpus_per_worker: 1 + system_envs: + VLLM_USE_V1: '1' + +reference: + model_args: + flash_attn: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + device_mapping: list(range(0,2)) + infer_batch_size: 8 + +rewards: + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2_5 + tag_included: [cn_k12, orca_math, olympiads, gsm8k, math, aops_forum] + world_size: 2 + infer_batch_size: 1 + query_filter_config: + type: mean_filter + filter_args: + threshold_up: 0.9 + threshold_down: 0.1 + # code_sandbox: + # use_local: true + # worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + # tag_included: [assert, input, livecodebench_random100] + # model_args: + # model_name_or_path: ${reward_pretrain} + # data_args: + # template: qwen2_5 + # world_size: 8 + # infer_batch_size: 1 + # query_filter_config: + # type: std_filter + # filter_args: + # std_threshold: 0 +# llm_judge: +# # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu +# worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker +# judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt +# judge_model_type: inference +# tag_included: [reference] +# model_args: +# model_name_or_path: AI-ModelScope/Qwen2.5-7B-Instruct-RLVR +# flash_attn: fa2 +# disable_gradient_checkpointing: true +# dtype: bf16 +# model_type: trl +# generating_args: +# max_new_tokens: 100 +# top_p: 0.8 +# top_k: 50 +# num_beams: 1 +# temperature: 0.8 +# num_return_sequences: 1 +# data_args: +# template: qwen2_5 +# strategy_args: +# strategy_name: hf_infer +# strategy_config: null +# device_mapping: list(range(7,8)) +# infer_batch_size: 1 \ No newline at end of file diff --git a/tests/pipeline/test_rlvr_pipeline.py b/tests/pipeline/test_rlvr_pipeline.py index 86ce7d9d..b82e3ced 100644 --- a/tests/pipeline/test_rlvr_pipeline.py +++ b/tests/pipeline/test_rlvr_pipeline.py @@ -16,13 +16,28 @@ def make_ppo_config(): - config_path = "." 
config_name = args.config_name + print(f"DEBUG: Loading config_name = '{config_name}'") initialize(config_path=config_path) cfg = compose(config_name=config_name) - ppo_config = from_dict(data_class=RLVRConfig, data=OmegaConf.to_container(cfg, resolve=True)) + + # Debug: Check what's in the raw config + print(f"DEBUG: Raw cfg.adv_estimator = '{cfg.get('adv_estimator', 'NOT_FOUND')}'") + print(f"DEBUG: Raw cfg keys: {list(cfg.keys())}") + + # Convert to container + cfg_container = OmegaConf.to_container(cfg, resolve=True) + print(f"DEBUG: Container adv_estimator = '{cfg_container.get('adv_estimator', 'NOT_FOUND')}'") + + ppo_config = from_dict(data_class=RLVRConfig, data=cfg_container) + print(f"DEBUG: Final ppo_config.adv_estimator = '{ppo_config.adv_estimator}'") + + # TEMPORARY FIX: Explicitly set adv_estimator from config if dacite is not working + if 'adv_estimator' in cfg_container: + print(f"DEBUG: Manually overriding adv_estimator to '{cfg_container['adv_estimator']}'") + ppo_config.adv_estimator = cfg_container['adv_estimator'] return ppo_config @@ -33,8 +48,8 @@ def test_make_ppo_config(): def test_ppo_pipeline(): - ppo_config = make_ppo_config() + print(f"DEBUG: After loading config, adv_estimator = '{ppo_config.adv_estimator}'") init() @@ -43,6 +58,8 @@ def test_ppo_pipeline(): pipeline.run() + print("RLVR Pipeline test completed successfully.") + if __name__ == "__main__": test_ppo_pipeline()
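For context on the TEMPORARY FIX in make_ppo_config above, a minimal, self-contained sketch of the dacite call it guards; ExampleConfig is a made-up stand-in for RLVRConfig, and in the healthy case the manual override is a no-op.

from dataclasses import dataclass

from dacite import from_dict


@dataclass
class ExampleConfig:
    adv_estimator: str = "gae"
    rollout_batch_size: int = 64


cfg_container = {"adv_estimator": "reinforce", "rollout_batch_size": 8}
ppo_config = from_dict(data_class=ExampleConfig, data=cfg_container)
assert ppo_config.adv_estimator == "reinforce"  # expected: the field is picked up from the dict
# If a field were silently left at its default (the failure mode being debugged above),
# ppo_config.adv_estimator = cfg_container["adv_estimator"] would restore it.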