
[Feature Request] Question about LSTMModules #3147

@itwasabhi


Motivation

Could I get more details on the tensor shapes required when using LSTMModule?

I'm particularly confused about the required input shape in recurrent mode. It seems the input needs to be (*B, T, F), and that the time dimension should not contain several consecutive trajectories stacked together. And since recurrent_mode is enabled for all losses, doesn't this mean that every time a loss function is called with a recurrent model, the data must be batched in the same way?
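To make the layout I have in mind concrete, here is a minimal pure-Python sketch of a (*B, T, F) batch with exactly one trajectory per row. It makes no TorchRL calls: `pad_by_trajectory` is a hypothetical helper, and nested lists stand in for tensors. The flat, time-ordered stream of steps is split wherever the trajectory id changes. Each contiguous run becomes one row, padded to the longest run, with a mask marking the real steps.

```python
from typing import List, Optional, Sequence, Tuple


def pad_by_trajectory(
    traj_ids: Sequence[int],
    features: Sequence[Sequence[float]],
    pad_value: float = 0.0,
) -> Tuple[List[List[List[float]]], List[List[bool]]]:
    """Hypothetical helper: turn a flat step stream into a [B, T, F] nested list.

    Each row holds one contiguous trajectory segment, so the time axis never
    crosses a trajectory boundary. Shorter rows are padded with `pad_value`;
    the returned mask is True for real steps and False for padding.
    """
    if len(traj_ids) != len(features):
        raise ValueError("traj_ids and features must have the same length")

    # Start a new segment whenever the trajectory id changes.
    segments: List[List[Sequence[float]]] = []
    prev: Optional[int] = None
    for tid, feat in zip(traj_ids, features):
        if tid != prev:
            segments.append([])
            prev = tid
        segments[-1].append(feat)

    t_max = max((len(seg) for seg in segments), default=0)
    n_feat = len(features[0]) if features else 0

    padded: List[List[List[float]]] = []
    mask: List[List[bool]] = []
    for seg in segments:
        n_pad = t_max - len(seg)
        padded.append([list(f) for f in seg] + [[pad_value] * n_feat for _ in range(n_pad)])
        mask.append([True] * len(seg) + [False] * n_pad)
    return padded, mask
```

For example, ids [0, 0, 0, 1, 1] with one feature per step give B=2 rows padded to T=3, and the mask flags the single padded step in the second row.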

I'm a bit confused about how this example works at all: https://docs.pytorch.org/rl/main/tutorials/dqn_with_rnn.html. The data passed into the loss function has shape (B1, B2), and each row contains several slices from different trajectories concatenated together (see traj_ids below). Wouldn't this mean that during training the recurrent state is passed between unrelated trajectories?

```
In [2]: s["collector", "traj_ids"]
Out[2]:
tensor([[10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
         13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14],
        [30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33,
         33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34],
        [20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22,
         22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,
          3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4]])
```
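The question boils down to whether the recurrent state is reset at the points where traj_ids changes inside a row. A minimal pure-Python sketch of that concern follows. It makes no TorchRL calls, the names `is_init_from_traj_ids` and `scan_with_resets` are hypothetical, and a toy scalar recurrence stands in for an LSTM. It derives a start-of-trajectory flag from a row of traj_ids, then shows that zeroing the state at those flags makes the concatenated row equivalent to processing each trajectory separately. TorchRL's InitTracker transform records a similar `is_init` flag, but whether LSTMModule uses it this way in recurrent mode is exactly what the documentation should clarify.

```python
from typing import List, Sequence


def is_init_from_traj_ids(traj_ids: Sequence[int]) -> List[bool]:
    """Hypothetical helper: True at the first step of every contiguous segment."""
    return [i == 0 or traj_ids[i] != traj_ids[i - 1] for i in range(len(traj_ids))]


def scan_with_resets(
    xs: Sequence[float], is_init: Sequence[bool], decay: float = 0.5
) -> List[float]:
    """Toy recurrence h_t = decay * h_{t-1} + x_t, standing in for an LSTM.

    The hidden state is zeroed wherever is_init is True, so no state flows
    from the end of one trajectory into the start of the next.
    """
    h = 0.0
    out: List[float] = []
    for x, start in zip(xs, is_init):
        if start:
            h = 0.0  # new trajectory: discard the previous trajectory's state
        h = decay * h + x
        out.append(h)
    return out


# A shortened version of the last row above: trajectory 0, then trajectory 1.
ids = [0, 0, 1, 1, 1]
flags = is_init_from_traj_ids(ids)  # [True, False, True, False, False]
with_resets = scan_with_resets([1.0] * 5, flags)  # [1.0, 1.5, 1.0, 1.5, 1.75]
no_resets = scan_with_resets([1.0] * 5, [True] + [False] * 4)  # state leaks across the boundary
```

Without the resets, the third step reads 1.75 instead of 1.0, i.e. trajectory 1 starts from trajectory 0's leftover state, which is the leakage I'm worried about.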

Solution

Documentation updates.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels: enhancement (New feature or request)
