
[BUG] Inadequate Default Priority Design in PrioritizedSampler #2210

Closed
3 tasks done
wertyuilife2 opened this issue Jun 6, 2024 · 5 comments · Fixed by #2215
Labels: bug (Something isn't working)

wertyuilife2 commented Jun 6, 2024

This issue comes from the original issue #2205.

Current Implementation and Issues

The current implementation maintains _max_priority, which tracks the maximum priority ever seen, not the maximum within the current buffer. Early in RL training, outliers can push _max_priority up and keep it there, making it unrepresentative. Additionally, _max_priority is initialized to 1, while most RL algorithms use the Bellman error as the priority, which is often much smaller (close to 0). As a result, _max_priority may never be updated, and new samples are given the default priority of 1. Their importance-sampling (PER) weight is then close to 0: they are sampled almost immediately, but they contribute little to the weighted loss, which reduces sample efficiency.
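
For illustration, a minimal numerical sketch of this effect (the hyperparameters and priority values below are illustrative assumptions, not TorchRL defaults):

import torch

# A buffer of 1000 samples whose Bellman-error priorities are ~0.001, plus one
# newly added sample that keeps the default priority of 1.0 because
# _max_priority was never updated.
alpha, beta, N = 0.6, 1.0, 1000
priorities = torch.full((N,), 1e-3)
priorities[-1] = 1.0  # the new sample, stuck at the default priority

probs = priorities.pow(alpha) / priorities.pow(alpha).sum()
weights = (N * probs).pow(-beta)
weights = weights / weights.max()  # usual PER normalization by the max weight

print(probs[-1] / probs[0])  # ~63: the new sample is drawn far more often...
print(weights[-1])           # ~0.016: ...but its loss weight is nearly zero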

Proposed Solution

Maintain a _neg_min_tree = MinSegmentTree() that tracks the maximum priority within the current buffer (by storing negated priorities), and add self._upper_priority = 1. With these, the relevant PrioritizedSampler methods can be updated as follows:

def default_priority(self, storage):
    max_priority = min(-self._neg_min_tree.query(0, len(storage)), self._upper_priority)
    if max_priority == 0:
        return self._upper_priority**self._alpha
    return (max_priority + self._eps) ** self._alpha

def mark_update(self, index, storage):
    self._neg_min_tree[index] = 0  # update negmintree before querying it in default_priority
    self.update_priority(index, self.default_priority(storage))

@torch.no_grad()
def _add_or_extend(self, index, storage):
    self._neg_min_tree[index] = 0  # update negmintree before querying it in default_priority
    priority = self.default_priority(storage)

    # ... existing code omitted ...

    self._sum_tree[index] = priority
    self._min_tree[index] = priority

@torch.no_grad()
def update_priority(self, index, priority):

    priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
    index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))

    # ... existing code omitted ...

    priority = torch.pow(priority + self._eps, self._alpha).clamp_max(
        self._upper_priority
    )
    self._sum_tree[index] = priority
    self._min_tree[index] = priority
    self._neg_min_tree[index] = -priority

This change means that the default_priority function needs to take storage as an additional parameter, which in turn affects several methods such as Sampler(ABC).extend(), Sampler(ABC).add(), and Sampler(ABC).mark_update(). I believe this is reasonable, akin to how Sampler.sample() already takes storage as a parameter.
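
For concreteness, the proposed signatures would look roughly like this on the abstract base class (a sketch of the signatures only, with bodies omitted; not TorchRL's actual code):

from abc import ABC, abstractmethod


class Sampler(ABC):
    # storage would now be threaded through add/extend/mark_update,
    # mirroring how sample() already receives it
    def add(self, index, storage) -> None:
        ...

    def extend(self, index, storage) -> None:
        ...

    def mark_update(self, index, storage) -> None:
        ...

    @abstractmethod
    def sample(self, storage, batch_size):
        ...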

Additional Context

See discussion in the original issue #2205.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
wertyuilife2 added the bug label on Jun 6, 2024
vmoens linked a pull request on Jun 7, 2024 that will close this issue
wertyuilife2 (Author) commented

@vmoens I found that #2215 does not actually work well :O

#2215 erases the previous max_priority on sampler.add() and sampler.extend(), but the max priority within the buffer can also decrease via update_priority(). For example, consider a buffer with three samples:

index: [0,1,2]
priority: [3,2,4]
max_priority_within_buffer: 4

update_priority(index=2, priority=1)

index: [0,1,2]
priority: [3,2,1]
max_priority_within_buffer: 3

So it seems the only way is to compute max_priority on the fly, using something like the negative min tree approach I proposed in #2205. As for the time cost, such a tree structure can perform queries and inserts in O(log N), just like the existing min_tree.
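
To illustrate the idea, here is a self-contained toy sketch (not TorchRL's MinSegmentTree): storing negated priorities in an array-based min segment tree lets the buffer-wide max be read off the root after O(log N) updates.

class NegMinTree:
    def __init__(self, capacity):
        # round capacity up to a power of two for a simple array-based tree
        self._cap = 1
        while self._cap < capacity:
            self._cap *= 2
        # empty slots hold 0, matching the "reset to 0 on write" convention above
        self._tree = [0.0] * (2 * self._cap)

    def __setitem__(self, idx, value):
        # set the leaf, then refresh the min on the path to the root
        i = idx + self._cap
        self._tree[i] = value
        i //= 2
        while i >= 1:
            self._tree[i] = min(self._tree[2 * i], self._tree[2 * i + 1])
            i //= 2

    def max_priority(self):
        # stored values are negated priorities (<= 0), so the minimum at the
        # root is minus the largest priority in the buffer
        return -self._tree[1]


tree = NegMinTree(capacity=8)
for idx, prio in enumerate([3.0, 2.0, 4.0]):
    tree[idx] = -prio            # store negated priorities
print(tree.max_priority())       # 4.0
tree[2] = -1.0                   # update_priority(index=2, priority=1)
print(tree.max_priority())       # 3.0: the max correctly decreases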

vmoens (Contributor) commented Jun 13, 2024

Hmm, in this case I think we can just add a check in update_priority comparing the provided index against the one recorded for max_priority, no?
I don't see why we need to compute it on the fly if we have access to the index.

wertyuilife2 (Author) commented Jun 13, 2024

I think we cannot determine what the new max priority will be after erasing the old max priority (we need to find the second largest priority in the buffer).

vmoens (Contributor) commented Jun 13, 2024

Oh yeah, but not every time the buffer is written! Only when you update the priority and the index matches that of the current max.
So I think the logic should be: cache the max. If the extend/update_priority index matches the index of the current max, erase that cache. Next time it's accessed, rescan the whole tree to find the current max.

wertyuilife2 (Author) commented

wow this seems like a feasible solution, I agree
