[Feature Request] A collector designed for LLMs #2872

@vmoens


Motivation

We need a collector that fits the LLM space well.

We will need to simplify the rollout function greatly - I would rewrite it from scratch.
The LLMEnv and vLLMWrapper can be used to simulate a real-life collector (just create a dummy dataloader that outputs strings and call LLMEnv.from_dataloader, see the tests for examples of this).

Features

We want this collector to have the right arguments:

  • rename "frames" to "steps" (or a more LLM appropriate term)
  • We just want a (list of) env constructors and (a list of) policy constructors (Note: currently you can only pass a single policy constructor)

We want to remove the following params:

  • "Devices" may not be useful in the LLM space: the policy is unlikely to sit on a single device, and the env should probably handle the device casting on its own (I'm not 100% confident about this one except for the fact that I don't think the current way we handle devices is proper)
  • Remove create_env_kwargs as you can always use functools.partial to pass kwargs to a specific env constructor
  • init_random_frames will never be used
  • split_trajs: not sure that is very useful for LLMs either
  • exploration_type: no need to explore - the LLM engine has its own exploration logic
  • interruptor: I would deactivate that in the v0 of the LLM collector
  • use_buffers should always be set to False since we're going to pass non-tensor data
  • trust_policy and compile_policy: I believe the policy constructor should take care of that.
  • no_cuda_sync will also probably not be useful
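
To illustrate the point about dropping create_env_kwargs, here is a minimal sketch of binding per-env kwargs with functools.partial. The `make_env` function, its parameters, and the dataset names are hypothetical stand-ins, not TorchRL APIs; the only assumption about the collector is that it calls each constructor with no arguments.

```python
from functools import partial


# Hypothetical env constructor used only for illustration (not a TorchRL class).
def make_env(dataset: str, max_steps: int = 8) -> dict:
    return {"dataset": dataset, "max_steps": max_steps}


# Bind the kwargs to each constructor instead of passing a separate
# create_env_kwargs list alongside the constructors.
env_constructors = [
    partial(make_env, "prompts_a", max_steps=4),
    partial(make_env, "prompts_b", max_steps=16),
]

# A collector would then only need to call each constructor with no arguments.
envs = [ctor() for ctor in env_constructors]
```

Each entry carries its own configuration, so the collector signature does not need a parallel kwargs argument.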

Additional options and features (stretch goals):

  • We want this collector to work with Ray like we do with the RayCollector. Since the RayCollector also has too many params, we should find a way of dealing with that. One way could be to dispatch on a backend argument in a metaclass:
    class _DCMeta(...):
        def __call__(cls, *args, **kwargs):
            # Default to None so the non-Ray path works when backend is omitted.
            backend = kwargs.pop("backend", None)
            if backend:
                # my_cuisine() is a placeholder for building the RayCollector kwargs.
                return RayCollector(**my_cuisine())
            return super().__call__(*args, **kwargs)

    class LLMCollector(DataCollector, metaclass=_DCMeta):
        def __init__(self, ..., backend: Literal["ray"] | None = None):
            ...
    That way people can just ask for Ray (or not!) and pass the Ray kwargs, and tada! The collector uses Ray (or not!).
  • We want to be able to get data in the following ways: (1) iterating over the collector, (2) calling next, (3) async using start + a replay buffer where the data is written and (4) using start + a queue.
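
The four ways of getting data listed above could look roughly like the sketch below. `ToyCollector`, `fake_rollout`, and the `start(sink)` / `join()` signatures are all hypothetical and only meant to make the access patterns concrete; a plain list stands in for the replay buffer, and a background thread stands in for whatever async mechanism the real collector would use.

```python
import queue
import threading
from typing import Callable, Iterator, Optional


class ToyCollector:
    """Hypothetical stand-in for the proposed LLM collector (not TorchRL code)."""

    def __init__(self, rollout_fn: Callable[[int], list], total_steps: int):
        self._rollout_fn = rollout_fn
        self._total_steps = total_steps
        self._iterator: Optional[Iterator[list]] = None
        self._thread: Optional[threading.Thread] = None

    def __iter__(self) -> Iterator[list]:
        # (1) Iterating yields one batch per step.
        for step in range(self._total_steps):
            yield self._rollout_fn(step)

    def next(self) -> list:
        # (2) Explicit stepping reuses a single underlying iterator.
        if self._iterator is None:
            self._iterator = iter(self)
        return next(self._iterator)

    def start(self, sink: Callable[[list], None]) -> None:
        # (3) and (4): collect in the background and hand each batch to `sink`,
        # e.g. a replay buffer's write method or queue.Queue.put.
        def _run() -> None:
            for batch in self:
                sink(batch)

        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()

    def join(self) -> None:
        if self._thread is not None:
            self._thread.join()


def fake_rollout(step: int) -> list:
    # Non-tensor (string) data, as an LLM rollout would produce.
    return [f"response-{step}"]


# (1) Iterate over the collector.
batches = list(ToyCollector(fake_rollout, 3))

# (2) Call next.
first = ToyCollector(fake_rollout, 3).next()

# (3) start + a replay buffer (a plain list stands in for the buffer).
replay_buffer: list = []
rb_collector = ToyCollector(fake_rollout, 3)
rb_collector.start(replay_buffer.extend)
rb_collector.join()

# (4) start + a queue.
q: queue.Queue = queue.Queue()
q_collector = ToyCollector(fake_rollout, 3)
q_collector.start(q.put)
q_collector.join()
queued = [q.get_nowait() for _ in range(3)]
```

Modes (3) and (4) differ only in the sink they write to, which suggests that a single `start` entry point taking a write target may be enough to cover both.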

cc @mikaylagawarecki @Lucaskabela

Labels

enhancement (New feature or request)
