4 changes: 4 additions & 0 deletions .gitmodules
@@ -1,3 +1,7 @@
[submodule "third_party/webshop-minimal"]
path = third_party/webshop-minimal
url = https://github.com/ZihanWang314/webshop-minimal.git
[submodule "third_party/vllm"]
path = third_party/vllm
url = https://github.com/taoluo/vllm.git
branch = roll
1 change: 1 addition & 0 deletions data/test_interrupt.jsonl
@@ -0,0 +1 @@
{"id": "1", "source": "deepmath_103k", "difficulty": "4.5", "prompt": "You are a senior systems researcher. Draft a 4,000-word white paper titled “post-training systems for RLHF” with these sections: abstract (≤150 words), introduction, related work, system architecture (with numbered sub-sections), evaluation methodology, experimental results (include tables), limitations, and future work. Use formal academic tone, cite at least eight landmark papers inline (APA), and end with a concise conclusion.", "messages": "[{\"role\": \"system\", \"content\": \"Please reason step by step, and put your final answer within \\\\boxed{}.\"}, {\"role\": \"user\", \"content\": \"You are a senior systems researcher. Draft a 4,000-word white paper titled “post-training systems for RLHF” with these sections: abstract (≤150 words), introduction, related work, system architecture (with numbered sub-sections), evaluation methodology, experimental results (include tables), limitations, and future work. Use formal academic tone, cite at least eight landmark papers inline (APA), and end with a concise conclusion.\"}]", "ground_truth": "1", "case_type": "", "test_case_function": "", "test_cases": "", "tag": "deepmath_103k"}
24 changes: 24 additions & 0 deletions roll/datasets/collator.py
@@ -10,6 +10,9 @@
from transformers import DataCollatorForSeq2Seq, PreTrainedTokenizerBase, ProcessorMixin, BatchFeature
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from transformers.utils import PaddingStrategy
from roll.utils.logging import get_logger

logger = get_logger()


def collate_fn_to_dict_list(data_list: list[dict]) -> dict:
@@ -98,6 +101,20 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        padded_features = [{k: v for k, v in feature.items() if k in self.padded_keys} for feature in features]
        un_padded_features = [{k: v for k, v in feature.items() if k not in self.padded_keys} for feature in features]

        # Debug: Log the input features
        logger.info(f"COLLATOR_DEBUG: Processing {len(features)} features")
        for i, feature in enumerate(features):
            if 'input_ids' in feature:
                input_ids = feature['input_ids']
                logger.info(f"COLLATOR_DEBUG: Feature_{i}: input_ids_len={len(input_ids)}, input_ids_first_10={input_ids[:10]}")

            # Log any text content for comparison
            for key in ['prompt', 'text', 'messages', 'ground_truth']:
                if key in feature:
                    text_data = feature[key]
                    sample_text = str(text_data)[:100] if text_data else "None"
                    logger.info(f"COLLATOR_DEBUG: Feature_{i}: {key}_sample='{sample_text}'")

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            padded_features,
@@ -106,6 +123,13 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        # Debug: Log the output batch
        if 'input_ids' in batch:
            logger.info(f"COLLATOR_DEBUG: Output batch: input_ids_shape={batch['input_ids'].shape}")
            for i in range(len(batch['input_ids'])):
                logger.info(f"COLLATOR_DEBUG: Batch_output_{i}: input_ids_first_10={batch['input_ids'][i][:10].tolist()}")

        batch["position_ids"] = torch.clip(torch.cumsum(batch["attention_mask"], dim=-1) - 1, min=0, max=None)
        un_padded_batch = collate_fn_to_dict_list(un_padded_features)
        batch.update(un_padded_batch)
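As an aside, the position_ids line in the collator above derives positions from the attention mask rather than assuming left-aligned sequences. Below is a minimal toy example of the same formula, purely illustrative and not part of the diff; the mask values are made up.

import torch

# Toy attention mask for two left-padded sequences (1 = real token, 0 = pad).
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [0, 1, 1, 1, 1]])
# Same formula as in the collator: cumulative sum minus one, clipped at zero,
# so padding positions stay 0 and real tokens count up from 0.
position_ids = torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0, max=None)
print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 0, 1, 2, 3]])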
677 changes: 597 additions & 80 deletions roll/distributed/scheduler/generate_scheduler.py

Large diffs are not rendered by default.

82 changes: 82 additions & 0 deletions roll/distributed/scheduler/protocol.py
@@ -18,6 +18,9 @@
from torch.utils.data import DataLoader

from roll.utils.functionals import union_two_dict, divide_by_chunk_size
from roll.utils.logging import get_logger

logger = get_logger()

try:
    tensordict.set_lazy_legacy(False).set()
@@ -606,7 +609,86 @@ def concat(data: List["DataProto"]) -> "DataProto":
        for batch in data:
            if batch.batch is not None:
                batch_lst.append(batch.batch)

        if len(batch_lst) > 0 and batch_lst[0] is not None:
            # Add comprehensive logging to verify tensor dimensions before concatenation
            logger.info(f"[CONCAT] Attempting to concatenate {len(batch_lst)} batches")

            # Collect all tensor keys and verify dimensions
            all_keys = set()
            for i, batch in enumerate(batch_lst):
                all_keys.update(batch.keys())

            logger.info(f"[CONCAT] Found tensor keys: {list(all_keys)}")

            # Verify dimensions for each key across all batches
            dimension_mismatches = []
            for key in all_keys:
                shapes = []
                request_ids = []

                for i, batch in enumerate(batch_lst):
                    if key in batch:
                        shape = batch[key].shape
                        shapes.append(shape)
                        # Try to get request_id from meta_info if available
                        request_id = "unknown"
                        if i < len(data) and data[i].meta_info and "request_id" in data[i].meta_info:
                            request_id = data[i].meta_info["request_id"]
                        request_ids.append(request_id)

                # Check for dimension mismatches (excluding dim=0 which should vary)
                if len(shapes) > 1:
                    expected_shape = shapes[0][1:]  # All dimensions except batch dimension
                    for j, shape in enumerate(shapes[1:], 1):
                        actual_shape = shape[1:]  # All dimensions except batch dimension
                        if expected_shape != actual_shape:
                            mismatch_info = {
                                "key": key,
                                "expected_shape": expected_shape,
                                "actual_shape": actual_shape,
                                "expected_request_id": request_ids[0],
                                "mismatched_request_id": request_ids[j],
                                "all_shapes": shapes,
                                "all_request_ids": request_ids
                            }
                            dimension_mismatches.append(mismatch_info)

                            logger.error(f"[CONCAT] DIMENSION MISMATCH for key '{key}':")
                            logger.error(f"  Expected shape: {expected_shape} (request_id: {request_ids[0]})")
                            logger.error(f"  Actual shape: {actual_shape} (request_id: {request_ids[j]})")
                            logger.error(f"  All shapes for this key: {shapes}")
                            logger.error(f"  All request IDs for this key: {request_ids}")

            # Log summary of all tensor shapes for debugging
            logger.info(f"[CONCAT] Summary of tensor shapes across all batches:")
            for key in all_keys:
                shapes = []
                request_ids = []
                for i, batch in enumerate(batch_lst):
                    if key in batch:
                        shapes.append(batch[key].shape)
                        request_id = "unknown"
                        if i < len(data) and data[i].meta_info and "request_id" in data[i].meta_info:
                            request_id = data[i].meta_info["request_id"]
                        request_ids.append(request_id)
                logger.info(f"  Key '{key}': shapes={shapes}, request_ids={request_ids}")

            # If there are dimension mismatches, log detailed error and raise exception
            if dimension_mismatches:
                logger.error(f"[CONCAT] Found {len(dimension_mismatches)} dimension mismatches:")
                for mismatch in dimension_mismatches:
                    logger.error(f"  Key: {mismatch['key']}")
                    logger.error(f"  Expected: {mismatch['expected_shape']} (request_id: {mismatch['expected_request_id']})")
                    logger.error(f"  Actual: {mismatch['actual_shape']} (request_id: {mismatch['mismatched_request_id']})")
                    logger.error(f"  All shapes: {mismatch['all_shapes']}")
                    logger.error(f"  All request IDs: {mismatch['all_request_ids']}")

                # Raise the original error with additional context
                raise RuntimeError(f"Tensor dimension mismatch detected. Found {len(dimension_mismatches)} mismatches. Check logs for details.")

            # If all dimensions match, proceed with concatenation
            logger.info(f"[CONCAT] All tensor dimensions verified. Proceeding with concatenation.")
            new_batch = torch.cat(batch_lst, dim=0)
        else:
            new_batch = None
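A quick illustration, independent of this diff, of the failure mode the new check in DataProto.concat guards against: torch.cat only tolerates differences along the concatenation dimension, so two batches whose sequences were padded to different lengths fail immediately. The tensor names and sizes below are hypothetical.

import torch

# Two "batches" padded to different sequence lengths (hypothetical sizes).
batch_a = torch.zeros(4, 128, dtype=torch.long)   # 4 samples, padded to 128 tokens
batch_b = torch.zeros(4, 256, dtype=torch.long)   # 4 samples, padded to 256 tokens

try:
    torch.cat([batch_a, batch_b], dim=0)
except RuntimeError as e:
    # torch.cat requires all dimensions except dim=0 to match, so this raises;
    # the logging added above reports which key and which request_id diverged.
    print(f"concat failed: {e}")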