Prioritized experience replay #1622

Status: Open. Wants to merge 31 commits into base: master.

Changes shown are from 8 of the 31 commits.

Commits:
c607585
Started PER
AlexPasqua Jul 22, 2023
57f1192
Added "add" method + other improvements
AlexPasqua Jul 23, 2023
2b9df33
Docstrings, type hints, doc
AlexPasqua Jul 23, 2023
31dc46c
Merge branch 'master' into prioritized-experience-replay
AlexPasqua Jul 23, 2023
1a32377
Merge branch 'master' into prioritized-experience-replay
araffin Jul 24, 2023
ccf7dc3
Merge branch 'master' into prioritized-experience-replay
AlexPasqua Aug 6, 2023
aee1d30
FIxed for pytype checks (partially)
AlexPasqua Aug 6, 2023
c51b173
make format
AlexPasqua Aug 6, 2023
18c9d28
Made pytype ignore type on PER's sample method
AlexPasqua Aug 6, 2023
840dde2
Merge branch 'master' into prioritized-experience-replay
araffin Aug 30, 2023
dcfbf88
Merge branch 'master' into prioritized-experience-replay
araffin Sep 29, 2023
fb33732
Switch to numpy for the backend
araffin Sep 29, 2023
f984e5c
Move to common and add tests
araffin Sep 29, 2023
5edf8bf
Updated DQN docs
AlexPasqua Sep 30, 2023
2f76038
Update doc
araffin Oct 2, 2023
42f2f4a
Rename things to be consistent with buffers.py
araffin Oct 2, 2023
007105f
Rename variables and add priority update
araffin Oct 2, 2023
cc37cba
Ignore mypy
araffin Oct 2, 2023
b60ef03
Add beta schedule
araffin Oct 3, 2023
ec272b9
Merge branch 'master' into prioritized-experience-replay
araffin Nov 8, 2023
a043cfd
Merge branch 'master' into prioritized-experience-replay
araffin Nov 22, 2023
f6accf9
Merge branch 'master' into prioritized-experience-replay
araffin Jan 30, 2024
b21ef33
Merge branch 'master' into prioritized-experience-replay
araffin May 6, 2024
f57444a
Merge branch 'master' into prioritized-experience-replay
araffin May 24, 2024
be00231
Minor fix in PER
AlexPasqua May 24, 2024
4390ec7
Merge branch 'master' into prioritized-experience-replay
araffin Jun 7, 2024
bee9cbe
Merge branch 'master' into prioritized-experience-replay
araffin Jul 7, 2024
fb1a9f7
Only convert to numpy if needed
araffin Jul 12, 2024
150b09a
Increase min priority to avoid division by zero
araffin Jul 12, 2024
5c0c79d
Merge branch 'master' into prioritized-experience-replay
araffin Jul 17, 2024
148e4aa
Merge branch 'master' into prioritized-experience-replay
araffin Nov 5, 2024
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -16,6 +16,7 @@ New Features:
^^^^^^^^^^^^^
- Added Python 3.11 support
- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)
+- Prioritized Experience Replay for DQN (@AlexPasqua)

`SB3-Contrib`_
^^^^^^^^^^^^^^
3 changes: 2 additions & 1 deletion docs/modules/dqn.rst
@@ -27,7 +27,8 @@ Notes
- Further reference: https://www.nature.com/articles/nature14236

.. note::
-    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
+    This implementation does **not** provide Rainbow DQN, but only vanilla Deep Q-Learning.
+    There are no extensions such as Double-DQN or Dueling-DQN, with the only exception being Prioritized Experience Replay.


Can I use?
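
A usage sketch for the feature this note advertises (assuming the buffer is wired through DQN's existing ``replay_buffer_class``/``replay_buffer_kwargs`` arguments and imported from the module added in this PR; the exact wiring is not part of this diff):

    from stable_baselines3 import DQN
    from stable_baselines3.dqn.prioritized_replay_buffer import PrioritizedReplayBuffer

    # alpha skews sampling toward high-TD-error transitions,
    # beta controls how much importance sampling corrects for that skew
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        replay_buffer_class=PrioritizedReplayBuffer,
        replay_buffer_kwargs=dict(alpha=0.6, beta=0.4),
    )
    model.learn(total_timesteps=10_000)
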
2 changes: 1 addition & 1 deletion stable_baselines3/common/buffers.py
@@ -281,7 +281,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
-        :return:
+        :return: a batch of sampled experiences from the buffer.
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
221 changes: 221 additions & 0 deletions stable_baselines3/dqn/prioritized_replay_buffer.py
@@ -0,0 +1,221 @@
import random
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize


class SumTree:
    """
    SumTree data structure for Prioritized Replay Buffer.
    This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay

    :param size: Max number of elements in the buffer.
    """

    def __init__(self, size: int):
        self.nodes = th.zeros(2 * size - 1)
        self.data = th.empty(size)
        self.size = size
        self.count = 0
        self.real_size = 0

    @property
    def p_total(self) -> float:
        """
        Returns the root node value, which represents the total sum of all priorities in the tree.

        :return: Total sum of all priorities in the tree.
        """
        return self.nodes[0].item()

    def update(self, data_idx: int, value: float) -> None:
        """
        Update the priority of a leaf node.

        :param data_idx: Index of the leaf node to update.
        :param value: New priority value.
        """
        idx = data_idx + self.size - 1  # child index in tree array
        change = value - self.nodes[idx]
        self.nodes[idx] = value
        # propagate the change up to the root
        parent = (idx - 1) // 2
        while parent >= 0:
            self.nodes[parent] += change
            parent = (parent - 1) // 2

    def add(self, value: float, data: int) -> None:
        """
        Add a new transition with priority value.

        :param value: Priority value.
        :param data: Transition data (here, the index of the transition in the replay buffer).
        """
        self.data[self.count] = data
        self.update(self.count, value)
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum: float) -> Tuple[int, float, th.Tensor]:
        """
        Get a leaf node index, its priority value and transition data by cumsum value.

        :param cumsum: Cumulative sum value.
        :return: Leaf node index, its priority value and transition data.
        """
        assert cumsum <= self.p_total

        # descend from the root: go left if the left subtree's sum covers cumsum,
        # otherwise go right and subtract the left subtree's sum
        idx = 0
        while 2 * idx + 1 < len(self.nodes):
            left, right = 2 * idx + 1, 2 * idx + 2
            if cumsum <= self.nodes[left]:
                idx = left
            else:
                idx = right
                cumsum = cumsum - self.nodes[left]

        data_idx = idx - self.size + 1
        return data_idx, self.nodes[idx].item(), self.data[data_idx]

    def __repr__(self) -> str:
        return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})"


class PrioritizedReplayBuffer(ReplayBuffer):
    """
    Prioritized Replay Buffer.
    Paper: https://arxiv.org/abs/1511.05952
    This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param alpha: How much prioritization is used (0 - no prioritization, 1 - full prioritization)
    :param beta: To what degree to use importance weights (0 - no correction, 1 - full correction)
    :param optimize_memory_usage: Not supported by this buffer (see warning below)
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        alpha: float = 0.6,
        beta: float = 0.4,
        optimize_memory_usage: bool = False,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs)

        # TODO: check this
        if optimize_memory_usage:
            warnings.warn("PrioritizedReplayBuffer does not support optimize_memory_usage=True during sampling")

        # PER params
        self.eps = 1e-8  # minimal priority, prevents zero probabilities
        self.alpha = alpha  # determines how much prioritization is used, alpha = 0 corresponding to the uniform case
        self.beta = beta  # determines the amount of importance-sampling correction
        self.max_priority = self.eps  # priority for new samples, init as eps

        # SumTree: data structure to store priorities
        self.tree = SumTree(size=buffer_size)

        self.real_size = 0
        self.count = 0

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Add a new transition to the buffer.

        :param obs: Starting observation of the transition to be stored.
        :param next_obs: Destination observation of the transition to be stored.
        :param action: Action performed in the transition to be stored.
        :param reward: Reward received in the transition to be stored.
        :param done: Whether the episode was finished after the transition to be stored.
        :param infos: Additional information given by the environment, if any.
        """
        # store transition index with maximum priority in sum tree,
        # so that recent transitions are likely to be sampled soon
        self.tree.add(self.max_priority, self.count)

        # update counters
        self.count = (self.count + 1) % self.buffer_size
        self.real_size = min(self.buffer_size, self.real_size + 1)

        # store transition in the buffer
        super().add(obs, next_obs, action, reward, done, infos)

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the prioritized replay buffer.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: a batch of sampled experiences from the buffer.
        """
        assert self.buffer_size >= batch_size, "The buffer contains fewer samples than the batch size requires."

        sample_idxs, tree_idxs = [], []
        priorities = th.empty(batch_size, 1, dtype=th.float)

        # To sample a minibatch of size k, the range [0, p_total] is divided equally into k ranges.
        # Next, a value is uniformly sampled from each range. Finally, the transitions that correspond
        # to each of these sampled values are retrieved from the tree.
        segment = self.tree.p_total / batch_size
        for i in range(batch_size):
            # extremes of the current segment
            a, b = segment * i, segment * (i + 1)

            # uniformly sample a value from the current segment
            cumsum = random.uniform(a, b)

            # tree_idx is the index of the sample in the tree, needed later to update priorities
            # sample_idx is the index of the sample in the buffer, needed to retrieve the actual transition
            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(int(sample_idx.item()))

        # probability of sampling transition i: P(i) = p_i^alpha / \sum_k p_k^alpha
        # where p_i > 0 is the priority of transition i
        probs = priorities / self.tree.p_total

        # Importance sampling weights: w_i = (N * P(i))^-beta,
        # scaled so that max_i w_i = 1.
        weights = (self.real_size * probs) ** -self.beta
        weights = weights / weights.max()

        env_indices = np.random.randint(0, high=self.n_envs, size=(len(sample_idxs),))

        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(
                self.observations[(np.array(sample_idxs) + 1) % self.buffer_size, env_indices, :], env
            )
        else:
            next_obs = self._normalize_obs(self.next_observations[sample_idxs, env_indices, :], env)

        batch = (
            self._normalize_obs(self.observations[sample_idxs, env_indices, :], env),
            self.actions[sample_idxs, env_indices, :],
            next_obs,
            # index dones and rewards per environment as well, mirroring the base ReplayBuffer
            self.dones[sample_idxs, env_indices].reshape(-1, 1),
            self._normalize_reward(self.rewards[sample_idxs, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, batch)))
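
At this commit, the ``weights`` computed in ``sample()`` implement the importance-sampling correction w_i = (N * P(i))^(-beta) from the paper but are not yet returned to the training loop, and the priorities stored in the tree are never refreshed (a later commit in this PR, "Rename variables and add priority update", addresses this). A minimal sketch of what such a refresh could look like, built only on the methods defined above; the ``update_priorities`` helper is hypothetical:

    from typing import List

    import numpy as np

    def update_priorities(buffer: PrioritizedReplayBuffer, sample_idxs: List[int], td_errors: np.ndarray) -> None:
        # Proportional variant of PER (arXiv:1511.05952): p_i = (|TD error_i| + eps) ** alpha
        priorities = (np.abs(td_errors) + buffer.eps) ** buffer.alpha
        for data_idx, priority in zip(sample_idxs, priorities):
            # SumTree.update expects the buffer (data) index, not the tree node index
            buffer.tree.update(data_idx, float(priority))
            # track the running max so new transitions enter the tree with the highest priority
            buffer.max_priority = max(buffer.max_priority, float(priority))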