Open
Labels
enhancement (New feature or request)
Description
Motivation
We need a collector that fits the LLM space well.
We will need to simplify the rollout function greatly - I would rewrite it from scratch.
The `LLMEnv` and `vLLMWrapper` can be used to simulate a real-life collector: just create a dummy dataloader that outputs strings and call `LLMEnv.from_dataloader` (see the tests for examples of this).
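As a rough illustration of the "dummy dataloader that outputs strings" idea, here is a minimal sketch in plain Python. The class name `DummyStringLoader` and its parameters are hypothetical and not part of the library. Whether `LLMEnv.from_dataloader` accepts such an object unchanged, and which other arguments it needs, should be checked against the tests mentioned above.

```python
import random
import string


class DummyStringLoader:
    """Hypothetical stand-in for a dataloader: yields batches of random strings."""

    def __init__(self, batch_size: int = 4, max_length: int = 10, seed: int = 0):
        self.batch_size = batch_size
        self.max_length = max_length
        self._rng = random.Random(seed)

    def _random_string(self) -> str:
        # Random lowercase "prompt" of 1..max_length characters.
        length = self._rng.randint(1, self.max_length)
        return "".join(self._rng.choice(string.ascii_lowercase) for _ in range(length))

    def __iter__(self):
        return self

    def __next__(self) -> list:
        # Endless stream: the consumer decides when to stop, not the loader.
        return [self._random_string() for _ in range(self.batch_size)]


loader = DummyStringLoader(batch_size=2)
batch = next(loader)  # a list of two short random strings
```

An object like this would be what gets handed to `LLMEnv.from_dataloader`; that call is deliberately left out of the sketch because its exact arguments are not given here.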
Features
We want this collector to have the right arguments:
- Rename "frames" to "steps" (or a more LLM-appropriate term)
- We just want a (list of) env constructors and a (list of) policy constructors (note: currently you can only pass a single policy constructor)
We want to remove the following params:
- "Devices" may not be useful in the LLM space: the policy is unlikely to sit on a single device, and the env should probably handle the device casting on its own (I'm not 100% confident about this one except for the fact that I don't think the current way we handle devices is proper)
- Remove `create_env_kwargs`, as you can always use `functools.partial` to pass kwargs to a specific env constructor
- `init_random_frames` will never be used
- `split_trajs`: not sure that is very useful for LLMs either
- `exploration_type`: no need to explore - the LLM engine has its own exploration logic
- `interruptor`: I would deactivate that in the v0 of the LLM collector
- `use_buffers` should always be set to `False` since we're going to pass non-tensor data
- `trust_policy` and `compile_policy`: I believe the policy constructor should take care of that.
- `no_cuda_sync` will also probably not be useful
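To make the `functools.partial` point concrete, here is a small sketch of how per-env kwargs could be baked into constructors instead of going through `create_env_kwargs`. The `make_env` function and its arguments are hypothetical placeholders; a dict stands in for the env that a real constructor would build.

```python
from functools import partial


def make_env(dataset: str, batch_size: int = 1) -> dict:
    # Hypothetical env constructor; a dict stands in for the env it would build.
    return {"dataset": dataset, "batch_size": batch_size}


# One zero-argument constructor per env, each with its own kwargs baked in.
env_constructors = [
    partial(make_env, "train_prompts", batch_size=4),
    partial(make_env, "eval_prompts", batch_size=8),
]

# The collector only ever needs to call the constructors with no arguments.
envs = [constructor() for constructor in env_constructors]
```

With this pattern the collector signature only needs the list of constructors, which is why a separate kwargs argument becomes redundant.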
Additional options and features (stretch goals):
- We want this collector to work with Ray like we do with the `RayCollector`. Since the `RayCollector` also has too many params, we should find a way of dealing with that. One way could be to do this:
  ```python
  class _DCMeta(...):
      def __call__(cls, *args, **kwargs):
          # Default to None so that omitting `backend` keeps the local collector.
          backend = kwargs.pop("backend", None)
          if backend:
              return RayCollector(**my_cuisine())
          return super().__call__(*args, **kwargs)

  class LLMCollector(DataCollector, metaclass=_DCMeta):
      def __init__(self, ..., backend: Literal["ray"] | None=None):
          ...
  ```
  That way people can just ask for Ray (or not!) and pass the Ray kwargs and tada! The collector uses Ray (or not!).
- We want to be able to get data in the following ways: (1) iterating over the collector, (2) calling `next`, (3) async using `start` + a replay buffer where the data is written, and (4) using `start` + a queue.
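The two stretch goals above can be sketched together in plain Python. Everything here is a toy under stated assumptions: `ToyLLMCollector`, `ToyRayCollector`, `_BackendMeta`, the `sink` argument of `start` and the `join` helper are hypothetical names, no Ray calls are made, and a plain list stands in for the replay buffer. The sketch only shows the metaclass dispatch on `backend` and the four ways of consuming data.

```python
import queue
import threading
from typing import Callable, List, Optional


class ToyRayCollector:
    """Hypothetical stand-in for a Ray-backed collector; no Ray calls are made."""

    def __init__(self, *args, ray_kwargs: Optional[dict] = None, **kwargs):
        self.ray_kwargs = ray_kwargs or {}


class _BackendMeta(type):
    def __call__(cls, *args, backend: Optional[str] = None, **kwargs):
        # Dispatch before __init__ runs: backend="ray" returns the Ray variant.
        if backend == "ray":
            return ToyRayCollector(*args, **kwargs)
        if backend is not None:
            raise ValueError(f"unknown backend: {backend!r}")
        return super().__call__(*args, **kwargs)


class ToyLLMCollector(metaclass=_BackendMeta):
    """Hypothetical local collector yielding a fixed number of string batches."""

    def __init__(self, total_batches: int = 3):
        self.total_batches = total_batches
        self._produced = 0
        self._thread: Optional[threading.Thread] = None

    def __iter__(self):
        return self

    def __next__(self) -> List[str]:
        if self._produced >= self.total_batches:
            raise StopIteration
        self._produced += 1
        return [f"response-{self._produced}"]

    def start(self, sink: Callable[[List[str]], None]) -> None:
        # Collect in a background thread, handing every batch to `sink`.
        def _run() -> None:
            for batch in self:
                sink(batch)

        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()

    def join(self) -> None:
        # Wait for the background collection to finish.
        if self._thread is not None:
            self._thread.join()


# Backend selection: same entry point, with or without Ray.
remote = ToyLLMCollector(backend="ray", ray_kwargs={"num_cpus": 2})

# (1) Iterate over the collector.
iterated = list(ToyLLMCollector())

# (2) Call next.
first = next(ToyLLMCollector())

# (3) start + a replay buffer (a plain list stands in for the buffer here).
replay_buffer: List[str] = []
collector = ToyLLMCollector()
collector.start(replay_buffer.extend)
collector.join()

# (4) start + a queue.
batches: queue.Queue = queue.Queue()
collector = ToyLLMCollector()
collector.start(batches.put)
collector.join()
queued = [batches.get_nowait() for _ in range(batches.qsize())]
```

Passing a `sink` callable to `start` is just one possible design: it lets the replay-buffer and queue modes share a single code path, since the only difference between them is which callable receives each batch.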
Lucaskabela