[BUG] Device Error in PrioritizedSliceSampler.sample() #2207

wertyuilife2 opened this issue Jun 6, 2024 · 0 comments
Labels: bug (Something isn't working)


wertyuilife2 commented Jun 6, 2024

Describe the bug

This issue was split off from the original issue #2205.

In commit c2e1c05, at lines 1213 and 1767 of samplers.py, `index` and `stop_idx` may end up on different devices: `stop_idx` can be on the GPU while `index` is not, so comparing them raises a device-mismatch error. Both lines should move `stop_idx` to `index.device` before the comparison:

```python
index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1).to(index.device)
```
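A minimal sketch of what the patched comparison does, using toy tensors rather than the sampler's real internals. The helper name `stop_mask` is hypothetical, and the 2-D layout (first column holding the storage position) is an assumption for illustration. The `(1, N)` row of sampled positions broadcasts against the `(M, 1)` column of stop positions into an `(M, N)` boolean mask, and `.to(index.device)` keeps this valid whether or not `stop_idx` lives on a CUDA device.

```python
import torch


def stop_mask(index: torch.Tensor, stop_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pairwise equality between sampled and stop positions.

    Returns a (num_stops, num_samples) bool mask where entry [i, j] is True
    when sample j sits at stop i. stop_idx is moved to index.device first,
    as proposed in the issue, so mixed CPU/GPU inputs do not raise.
    """
    return index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1).to(index.device)


# Toy values (assumed layout): the first column holds the storage position.
index = torch.tensor([[0], [3], [5], [9]])  # stays on the CPU
stop_idx = torch.tensor([[3], [9]])
# Put stop_idx on the GPU when one is available, mimicking the bug scenario.
if torch.cuda.is_available():
    stop_idx = stop_idx.cuda()

mask = stop_mask(index, stop_idx)
print(mask)
```

The mask is always produced on `index.device`, so downstream code that indexes `index` with it does not need to care where `stop_idx` came from.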

To Reproduce

Running the code below on a machine with a CUDA device raises the device-mismatch error.

```python
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        # CUDA storage is what triggers the device mismatch inside sample().
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        traj = rb.sample()["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))


test_sampler()
```
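The mismatch can also be isolated without torchrl. This sketch is an illustration under stated assumptions, not the sampler's code: `compare_across_devices` is a hypothetical name and the tensors are toy values. It puts one operand on the CPU and the other on CUDA, runs the unpatched comparison, and reports whether it raised; on a CPU-only machine it returns early.

```python
import torch


def compare_across_devices() -> str:
    """Hypothetical check: run the unpatched comparison with operands on CPU and CUDA."""
    if not torch.cuda.is_available():
        return "skipped: no CUDA device"
    index = torch.tensor([[0], [3], [5], [9]])          # CPU tensor
    stop_idx = torch.tensor([[3], [9]], device="cuda")  # GPU tensor
    try:
        # Same expression as the original line, without .to(index.device).
        index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1)
    except RuntimeError as err:
        return f"RuntimeError: {err}"
    return "no error"


print(compare_across_devices())
```

On a CUDA machine this reports a `RuntimeError`; adding `.to(index.device)` to `stop_idx` makes the same comparison return a mask instead.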

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
wertyuilife2 added the bug label on Jun 6, 2024.
vmoens closed this as completed on Jun 7, 2024.