[Discussion] Remember TorchRL: the state of memory in TorchRL #2325

Open
@matteobettini

Remember TorchRL: the state of memory in TorchRL

Hello! This is a discussion post to recap the state of memory models in TorchRL: what's doable, what's not doable, what is the way to do things, and what is missing.

Why

The goal of this post is to outline what we can improve in the library to give users access to state-of-the-art memory models for tackling partially observable (PO) problems.

Disclaimer

I am not a memory guy. In fact, I cannot even remember any birthdays and have to put them in my calendar. The idea of this post is to stimulate discussion with memory researchers. So if anything I write here is incorrect, please point it out so I can learn and fix it.

The state of memory in RL

Traditional architectures are GRUs and LSTMs (available in this library). Although GRUs have been shown to achieve good results in PO problems (https://arxiv.org/abs/2303.01859), they have problems related to the way they process the sequence.

In fact, to run LSTMs and GRUs in parallel over time, you need a fixed sequence length without any terminations within it. Traditionally (and in this library), this is handled by splitting and padding trajectories into fixed-length sequences that do not contain any dones.

This has some major issues. Most importantly, the sequence length is a hyperparameter that directly affects the memory length: a low value prevents the model from remembering things too far in the past, while a high value causes heavy padding and inefficiency.
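The split-and-pad scheme described above can be sketched roughly as follows. This is a minimal illustration using plain Python lists rather than tensors; the function name `split_and_pad` and its arguments are hypothetical and not part of TorchRL.

```python
def split_and_pad(obs, dones, seq_len, pad_value=0.0):
    """Cut a rollout at episode ends, chunk each episode into pieces of at
    most `seq_len` steps, and right-pad every piece to exactly `seq_len`."""
    # 1. Split the flat rollout into episodes at every done flag.
    episodes, current = [], []
    for x, done in zip(obs, dones):
        current.append(x)
        if done:
            episodes.append(current)
            current = []
    if current:
        episodes.append(current)  # unfinished episode at the end of the rollout

    # 2. Chunk each episode and pad the last chunk; the mask marks real steps.
    sequences, masks = [], []
    for ep in episodes:
        for start in range(0, len(ep), seq_len):
            chunk = ep[start:start + seq_len]
            n_pad = seq_len - len(chunk)
            sequences.append(chunk + [pad_value] * n_pad)
            masks.append([True] * len(chunk) + [False] * n_pad)
    return sequences, masks


# A 10-step rollout containing a 3-step episode followed by a 7-step episode.
obs = [float(i) for i in range(10)]
dones = [False, False, True, False, False, False, False, False, False, True]
sequences, masks = split_and_pad(obs, dones, seq_len=4)
valid_steps = sum(sum(m) for m in masks)
padded_slots = len(masks) * 4 - valid_steps
```

In this toy rollout, the 7-step episode is cut into two sequences, so a recurrent model that restarts from a fresh hidden state at each sequence cannot carry information from step 3 to step 7. That is the memory-length limitation, and the padded slots are the inefficiency.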

Recently, a new class of sequence models, sometimes called Linear Recurrent Models or Linear Transformers, has been introduced.
These models can be run in parallel over the time dimension with subquadratic space complexity. Examples are S5 and Fast and Forgetful Memory.

Most importantly, these models do not require a fixed sequence length (and thus padding) and can be run on consecutive terminating trajectories.
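To make this concrete, here is a minimal sketch of a scalar linear recurrence `h_t = a_t * h_{t-1} + b_t` computed with an associative operator. It is a toy stand-in, not S5 or Fast and Forgetful Memory, and all names are hypothetical. Episode boundaries are handled by setting `a_t = 0` at the first step of each episode, which is why no padding or fixed sequence length is needed. The scan is written in the Hillis-Steele style: every position within a round is independent, so a parallel implementation could run each round in one batched operation.

```python
def combine(left, right):
    """Associative operator for h_t = a_t * h_{t-1} + b_t.
    Composing step (a1, b1) followed by step (a2, b2) gives another step."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(elems):
    """Reference implementation: one step at a time (the "cell" version)."""
    out = [elems[0]]
    for e in elems[1:]:
        out.append(combine(out[-1], e))
    return out


def hillis_steele_scan(elems):
    """Inclusive scan in O(log T) rounds; each round is parallelisable."""
    out = list(elems)
    shift = 1
    while shift < len(out):
        out = [
            out[i] if i < shift else combine(out[i - shift], out[i])
            for i in range(len(out))
        ]
        shift *= 2
    return out


decay = 0.5
xs = [1.0, 1.0, 1.0, 1.0, 1.0]
dones = [False, False, True, False, False]
# A step is the first of an episode if it is step 0 or follows a done.
is_first = [True] + dones[:-1]
# a_t = 0 at episode starts wipes the state carried over from the past.
elems = [(0.0 if first else decay, x) for first, x in zip(is_first, xs)]

hidden_seq = [b for _, b in sequential_scan(elems)]
hidden_par = [b for _, b in hillis_steele_scan(elems)]
```

Both scans produce the same hidden states, and the state restarts at the step that follows the done. The log-round version does more total work than the loop, but the work inside each round can be batched, which is what a parallel associative scan primitive would exploit.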

To be precise, LSTMs and GRUs could also be used without padding, but this would mean calling their "Cell" versions in a for loop over time. This would lead to better results, but at a higher computational cost (as they could not be batched over time).
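The for-loop approach can be sketched as follows. To stay dependency-free, this uses a hand-written scalar GRU-style cell with made-up weights and no biases; in practice one would call a real cell module such as `torch.nn.GRUCell` inside the loop. The names `gru_cell` and `unroll_with_resets` are hypothetical.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_cell(x, h, w):
    """One step of a scalar GRU-style cell (biases omitted for brevity)."""
    z = sigmoid(w["wz"] * x + w["uz"] * h)          # update gate
    r = sigmoid(w["wr"] * x + w["ur"] * h)          # reset gate
    n = math.tanh(w["wn"] * x + r * (w["un"] * h))  # candidate state
    return (1.0 - z) * n + z * h


def unroll_with_resets(xs, dones, w, h0=0.0):
    """Loop over concatenated episodes, resetting the state after each done."""
    h, hs = h0, []
    for x, done in zip(xs, dones):
        h = gru_cell(x, h, w)
        hs.append(h)
        if done:
            h = h0  # the next step belongs to a new episode
    return hs


w = {"wz": 0.5, "uz": -0.3, "wr": 0.8, "ur": 0.4, "wn": 1.0, "un": 0.9}
xs = [0.5, -1.0, 2.0, 0.3, 0.7]
dones = [False, False, True, False, False]

hs = unroll_with_resets(xs, dones, w)
# Running the second episode on its own gives the same hidden states.
hs_second_alone = unroll_with_resets(xs[3:], dones[3:], w)
# Ignoring the done flags leaks state across the episode boundary.
hs_no_reset = unroll_with_resets(xs, [False] * len(xs), w)
```

No padding or splitting is needed, and each episode sees exactly its own history, at the price of one cell call per time step.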

What is available in TorchRL

Models:

  • LSTMCell -> runs a single pass of an LSTM
  • LSTM -> runs multiple passes of an LSTM on a batch of data with a fixed sequence length
  • GRUCell -> runs a single pass of a GRU
  • GRU -> runs multiple passes of a GRU on a batch of data with a fixed sequence length

Tensordict modules:

  • GRU and LSTM modules to simplify the use of the above in torchrl

Tutorials:

  • RECURRENT DQN: TRAINING RECURRENT POLICIES -> shows how to train DQN with a policy that uses an LSTM. The LSTM is used on a single step during collection and on a fixed sequence length during training. To achieve this, trajectories are split into padded sequences such that there are no episode dones within a sequence.

Replay buffer:

  • SliceSampler -> makes it possible to store trajectories by concatenating them along the time dimension, and provides flexibility in how they are sampled. This enables the type of buffer described in https://arxiv.org/abs/2402.09900
  • We can also stack trajectories along a fixed and padded time dimension like traditional buffers (inefficient)
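The idea behind sampling from concatenated trajectories can be sketched as below. This is plain Python for illustration only and is not TorchRL's SliceSampler implementation; the function names and the choice to skip trajectories shorter than the slice length are assumptions.

```python
import random


def episode_bounds(traj_ids):
    """Return (start, end_exclusive) pairs for each run of equal trajectory id."""
    bounds, start = [], 0
    for i in range(1, len(traj_ids) + 1):
        if i == len(traj_ids) or traj_ids[i] != traj_ids[start]:
            bounds.append((start, i))
            start = i
    return bounds


def sample_slices(traj_ids, slice_len, num_slices, rng):
    """Sample index slices that never cross a trajectory boundary."""
    candidates = [b for b in episode_bounds(traj_ids) if b[1] - b[0] >= slice_len]
    if not candidates:
        raise ValueError("no trajectory is long enough for the requested slice")
    slices = []
    for _ in range(num_slices):
        start, end = rng.choice(candidates)
        offset = rng.randint(start, end - slice_len)  # inclusive on both ends
        slices.append(list(range(offset, offset + slice_len)))
    return slices


# Three trajectories of lengths 5, 3 and 6 stored back to back, without padding.
traj_ids = [0] * 5 + [1] * 3 + [2] * 6
slices = sample_slices(traj_ids, slice_len=4, num_slices=8, rng=random.Random(0))
```

Storage holds each step exactly once, and the fixed length is only imposed at sampling time, so the same buffer can serve different slice lengths.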

What is doable in TorchRL

  • TorchRL buffers and samplers can deal with both padded sequences and concatenated sequences (aligning with the buffer proposed in https://arxiv.org/abs/2402.09900)
  • TorchRL can run LSTMs and GRUs both on single steps and in parallel on padded trajectories
  • It is also possible to avoid padding and run LSTM and GRU cells during training using for loops over time, leading to more correct results but worse computational performance

What is not available in TorchRL

Implementing Linear Recurrent Models efficiently in torch is currently not possible due to the lack of Parallel Associative Scan (pytorch/pytorch#95408), a feature that, on the contrary, is available in JAX and has allowed progress in RL memory research.

These models could still be implemented in their "cell" version (a for loop instead of a parallel scan). This would be quite inefficient, although it would lead to better task performance than traditional architectures.

What are the next steps?

Easy

I believe that we should make it easier for users to approach memory models in torchrl. This could be done by adding a new tutorial that shows how to use the SliceSampler in a memory context (maybe with a GRU since the other tutorial uses LSTM).

It would also be cool to have a tutorial that shows users how to use the current memory models without padding (in their cell version), which is a better, although less efficient, implementation.

Medium

We could implement Linear Recurrent Models in their inefficient version, avoiding the use of the Parallel Associative Scan.

Hard

Introduce Parallel Associative Scan to PyTorch (pytorch/pytorch#95408), allowing us to code the state-of-the-art memory models and catch up with JAX.

Conclusions

I hope to have given a complete picture of the state of memory in TorchRL. If anything is missing or incorrect, please point it out and I will update the comment.

I am happy to put coding work into a direction once we have identified one. My only limitation is that I would probably not be able to add torch C++ or CUDA kernels without a significant learning experience.

cc @vmoens @albertbou92 @smorad @EdanToledo
