Creating a new Dataset source throws in rollout.py because len(data_source.dataset) is used. #1498

@TSunny007

Description

I'm trying to create a custom data source by subclassing DataSource:


import json
import logging
from slime.rollout.data_source import DataSource
from slime.utils.processing_utils import load_processor, load_tokenizer
from slime.utils.types import Sample

logger = logging.getLogger(__name__)


class SFTDataSource(DataSource):
    def __init__(self, args):
        self.args = args
        self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
        self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True)
        self.file = args.prompt_data
        if not self.file:
            raise ValueError("prompt_data argument is required for SFTDataSource")
        self.stream = open(self.args.prompt_data)


    
    def get_samples(self, num_samples: int) -> list[list[Sample]]:
        """Return num_samples samples"""
        samples = []
        while len(samples) < num_samples:
            line = self.stream.readline()
            
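            # EOF: rewind to the beginning and read the first line again,
            # otherwise json.loads() below receives an empty string
            if not line:
                self.stream.seek(0)
                line = self.stream.readline()
                if not line:
                    raise ValueError(f"{self.file} is empty")

            data = json.loads(line)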
            # Default to None so the name is bound even when no processor is loaded
            multimodal_inputs = None
            if self.processor:
                from slime.utils.processing_utils import process_vision_info

                assert isinstance(
                    data, list
                ), f"prompt must be a list when processor is not None, got {type(data)} instead"
                multimodal_inputs = process_vision_info(data, self.processor)

            metadata = data.get("metadata", {})
            if "id" in data:
                metadata["id"] = data["id"]
            
            samples.append([Sample(
                prompt=data,
                metadata=metadata,
                multimodal_inputs=multimodal_inputs,
            )])
        
        return samples

        
    def add_samples(self, samples: list[list[Sample]]):
        """Add samples back to the data source"""
        pass
        
    def save(self, rollout_id):
        """Save state for checkpointing"""
        pass
        
    def load(self, rollout_id=None):
        """Load state from checkpoint"""
        pass

The details of this loader aren't important here; I'll expand it to support my use cases soon.

The job fails with the following error:

---------------------------------------
Job 'raysubmit_bnqeJDFmQkNanYpK' failed
---------------------------------------

Status message: Job entrypoint command failed with exit code 1, last available logs (truncated to 20,000 chars):
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::RolloutManager.get_num_rollout_per_epoch() (pid=1559250, ip=192.168.142.26, actor_id=e87f307ccebbed94bbf67f7602000000, repr=<slime.ray.rollout.RolloutManager object at 0x7fc3a1df15e0>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/slime/slime/ray/rollout.py", line 127, in get_num_rollout_per_epoch
    return len(self.data_source.dataset) // self.args.rollout_batch_size
               ^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'SFTDataSource' object has no attribute 'dataset'

It'd be great if you could suggest the best way to overcome this. We want a streaming dataset, because the current toy example's pattern of loading the whole dataset in the constructor does not scale. I'm happy to implement the fix myself based on your guidance.
Thanks!
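
For reference, here is a minimal sketch of one possible stopgap. The traceback shows that get_num_rollout_per_epoch only calls len() on data_source.dataset, so an object that counts lines lazily could stand in for it without loading any rows. StreamingLengthView and num_rollout_per_epoch are hypothetical names, not part of slime, and the sketch assumes that nothing else in slime reads data_source.dataset, which has not been verified.

```python
from pathlib import Path


class StreamingLengthView:
    """Hypothetical stand-in for `dataset`: supports len() without keeping rows in memory."""

    def __init__(self, path):
        self.path = Path(path)
        self._length = None  # computed on first len() call, then cached

    def __len__(self):
        if self._length is None:
            # Stream through the file once, counting non-empty lines
            with self.path.open("rb") as f:
                self._length = sum(1 for line in f if line.strip())
        return self._length


def num_rollout_per_epoch(dataset, rollout_batch_size):
    # Mirrors the expression shown in the traceback (hypothetical helper, for illustration)
    return len(dataset) // rollout_batch_size
```

With this in place, SFTDataSource.__init__ could set self.dataset = StreamingLengthView(args.prompt_data). The one-time line count still reads the whole file once, but it keeps memory use constant. A cleaner long-term fix would probably be for the rollout manager to ask the data source for its length rather than reaching into a dataset attribute.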
