-
Notifications
You must be signed in to change notification settings - Fork 503
Open
Description
I'm trying to create an abstract dataset
import json
import logging
from slime.rollout.data_source import DataSource
from slime.utils.processing_utils import load_processor, load_tokenizer
from slime.utils.types import Sample
logger = logging.getLogger(__name__)
class SFTDataSource(DataSource):
def __init__(self, args):
self.args = args
self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True)
self.file = args.prompt_data
if not self.file:
raise ValueError("prompt_data argument is required for SFTDataSource")
self.stream = open(self.args.prompt_data)
def get_samples(self, num_samples: int) -> list[list[Sample]]:
"""Return num_samples samples"""
samples = []
while len(samples) < num_samples:
line = self.stream.readline()
# EOF - reset to beginning
if not line:
self.stream.seek(0)
data = json.loads(line)
if self.processor:
from slime.utils.processing_utils import process_vision_info
assert isinstance(
data, list
), f"prompt must be a list when processor is not None, got {type(data)} instead"
multimodal_inputs = process_vision_info(data, self.processor)
metadata = data.get("metadata", {})
if "id" in data:
metadata["id"] = data["id"]
samples.append([Sample(
prompt=data,
metadata=metadata,
multimodal_inputs=multimodal_inputs,
)])
return samples
def add_samples(self, samples: list[list[Sample]]):
"""Add samples back to the data source"""
pass
def save(self, rollout_id):
"""Save state for checkpointing"""
pass
def load(self, rollout_id=None):
"""Load state from checkpoint"""
pass
The content of this loader isn't what's important - it'll be expanded to support my use cases very soon.
What is throwing is the following:
---------------------------------------
Job 'raysubmit_bnqeJDFmQkNanYpK' failed
---------------------------------------
Status message: Job entrypoint command failed with exit code 1, last available logs (truncated to 20,000 chars):
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::RolloutManager.get_num_rollout_per_epoch() (pid=1559250, ip=192.168.142.26, actor_id=e87f307ccebbed94bbf67f7602000000, repr=<slime.ray.rollout.RolloutManager object at 0x7fc3a1df15e0>)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/slime/slime/ray/rollout.py", line 127, in get_num_rollout_per_epoch
return len(self.data_source.dataset) // self.args.rollout_batch_size
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'SFTDataSource' object has no attribute 'dataset'
It'd be great if you can suggest the best way to overcome this, especially as we want to have a streaming dataset since the current toy example's pattern of initializing the whole dataset in the constructor is not scalable. I'm happy to implement the fix myself based on your guidance.
Thanks!
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels