Description
This issue comes from the original issue #2205.
Current Implementation and Issues
The current implementation maintains `_max_priority`, which represents the maximum priority over all samples ever seen, not just those in the current buffer. Early in RL training, outliers can drive `_max_priority` up and keep it there, making it unrepresentative. Additionally, `_max_priority` is initialized to 1, while most RL algorithms use the Bellman error as the priority, which is often much smaller (close to 0), so `_max_priority` may never be updated. New samples are therefore assigned a priority of 1: they are sampled almost immediately, but their PER importance-sampling weight ends up close to 0, so they contribute little to the weighted loss, reducing sample efficiency.
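For intuition, here is a toy computation (not TorchRL code; the buffer size and priority values are made up) of how a default priority of 1 skews both the sampling probability and the importance-sampling weight of a new sample:

```python
import numpy as np

# Made-up numbers: 999 old samples with Bellman-error priorities around 0.001,
# plus one new sample that receives the default priority of 1 because
# _max_priority was never updated.
alpha, beta = 0.6, 0.4
priorities = np.array([0.001] * 999 + [1.0])

probs = priorities**alpha / np.sum(priorities**alpha)
weights = (len(priorities) * probs) ** (-beta)
weights = weights / weights.max()  # standard PER normalization by the max weight

print(probs[-1] / probs[0])  # the new sample is drawn roughly 60x more often
print(weights[-1])           # yet its normalized importance weight is only ~0.2
```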
Proposed Solution
Maintain a `_neg_min_tree = MinSegmentTree()` that tracks the maximum priority in the current buffer (by storing negated priorities and querying the minimum), and add `self._upper_priority = 1`. With these, part of the `PrioritizedSampler` methods can be updated as follows:
```python
def default_priority(self, storage):
    max_priority = min(-self._neg_min_tree.query(0, len(storage)), self._upper_priority)
    if max_priority == 0:
        return self._upper_priority**self._alpha
    return (max_priority + self._eps) ** self._alpha

def mark_update(self, index, storage):
    # update _neg_min_tree before it is queried in default_priority
    self._neg_min_tree[index] = 0
    self.update_priority(index, self.default_priority(storage))

@torch.no_grad()
def _add_or_extend(self, index, storage):
    # update _neg_min_tree before it is queried in default_priority
    self._neg_min_tree[index] = 0
    priority = self.default_priority(storage)
    # ... existing code omitted ...
    self._sum_tree[index] = priority
    self._min_tree[index] = priority

@torch.no_grad()
def update_priority(self, index, priority):
    priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
    index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
    # ... existing code omitted ...
    priority = torch.pow(priority + self._eps, self._alpha).clamp_max(
        self._upper_priority
    )
    self._sum_tree[index] = priority
    self._min_tree[index] = priority
    self._neg_min_tree[index] = -priority
```
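As a side note, the negated-min trick itself is easy to verify in isolation. Below is a minimal sketch using a plain tensor instead of TorchRL's segment trees (the capacity and helper functions are hypothetical):

```python
import torch

# Minimal standalone sketch of the negated-min idea (not TorchRL's
# MinSegmentTree): the maximum priority over the written part of the buffer
# equals minus the minimum of the negated priorities over that range.
capacity = 1000  # hypothetical buffer capacity
neg_priorities = torch.zeros(capacity)

def write(index, priority):
    neg_priorities[index] = -priority

def current_max(size):
    # plays the role of -self._neg_min_tree.query(0, len(storage)),
    # computed naively in O(size) here instead of O(log size)
    return -neg_priorities[:size].min()

write(0, 0.05)
write(1, 0.02)
print(current_max(2))  # tensor(0.0500)
```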
This change implies that `default_priority` needs to take `storage` as an additional parameter, which eventually affects several methods such as `Sampler(ABC).extend()`, `Sampler(ABC).add()`, and `Sampler(ABC).mark_update()`. I believe this is reasonable, akin to how `Sampler.sample()` already takes `storage` as a parameter.
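A hedged sketch of what the affected `Sampler(ABC)` signatures might look like after the change (the current signatures are assumed from the description above, not taken from the TorchRL source):

```python
from abc import ABC, abstractmethod

# Hedged sketch of the signature change, not TorchRL source: storage becomes
# an argument of add/extend/mark_update, mirroring how sample() already
# receives it. A None default could keep existing samplers backward compatible.
class Sampler(ABC):
    @abstractmethod
    def sample(self, storage, batch_size):
        ...

    def add(self, index, storage=None):
        ...

    def extend(self, index, storage=None):
        ...

    def mark_update(self, index, storage=None):
        ...
```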
Additional Context
See discussion in the original issue #2205.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)