fix format, bug, and readme
liuzuxin committed Oct 7, 2023
1 parent f6d6479 commit e466e8d
Showing 9 changed files with 32 additions and 21 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -35,9 +35,17 @@
The Fast Safe Reinforcement Learning (FSRL) package provides modularized implementations
of Safe RL algorithms based on PyTorch and the [Tianshou](https://tianshou.readthedocs.io/en/master/) framework. Safe RL is a rapidly evolving subfield of RL, focusing on ensuring the safety of learning agents during the training and deployment process. The study of Safe RL is essential because it addresses the critical challenge of preventing unintended or harmful actions while still optimizing an agent's performance in complex environments.

This project offers high-quality and fast implementations of popular Safe RL algorithms, serving as an ideal starting point for those looking to explore and experiment in this field. By providing a comprehensive and accessible toolkit, the FSRL package aims to accelerate research in this crucial area and contribute to the development of safer and more reliable RL-powered systems.

**Please note that this project is still under active development, and major updates might be expected.** Your feedback and contributions are highly appreciated, as they help us improve the FSRL package.
This project offers high-quality and fast implementations of popular Safe RL algorithms, serving as an ideal starting point for those looking to explore and experiment in this field. By providing a comprehensive and accessible toolkit, the FSRL package aims to accelerate research in this crucial area and contribute to the development of safer and more reliable RL-powered systems. Your feedback and contributions are highly appreciated, as they help us improve the FSRL package.

If you find this code useful, please cite:
```bibtex
@article{liu2023datasets,
  title={Datasets and Benchmarks for Offline Safe Reinforcement Learning},
  author={Liu, Zuxin and Guo, Zijian and Lin, Haohong and Yao, Yihang and Zhu, Jiacheng and Cen, Zhepeng and Hu, Hanjiang and Yu, Wenhao and Zhang, Tingnan and Tan, Jie and others},
  journal={arXiv preprint arXiv:2306.09303},
  year={2023}
}
```

## 🌟 Key Features

3 changes: 1 addition & 2 deletions docs/tutorials/benchmark.rst
@@ -90,7 +90,7 @@ Safety-Gymnasium-Navigation-Tasks
<option value="SafetyPointGoal1Gymnasium-v0">SafetyPointGoal1Gymnasium-v0</option>
<option value="SafetyPointGoal2Gymnasium-v0">SafetyPointGoal2Gymnasium-v0</option>
<option value="SafetyPointPush1Gymnasium-v0">SafetyPointPush1Gymnasium-v0</option>
<option value="SafetyPointPush2Gymnasium-v0">SafetyPointPush2Gymnasium-v0</option>
<option value="SafetyPointPush2Gymnasium-v0">SafetyPointPush2Gymnasium-v0</option>
<!-- Add more options as needed -->
</select>

@@ -111,4 +111,3 @@ Safety-Gymnasium-Navigation-Tasks
</div>

<script src="../_static/js/benchmark.js"></script>

11 changes: 7 additions & 4 deletions fsrl/agent/cpo_agent.py
Expand Up @@ -101,10 +101,13 @@ def __init__(
self.logger = logger
self.cost_limit = cost_limit

if np.isscalar(cost_limit):
    cost_dim = 1
else:
    raise RuntimeError("CPO does not support multiple costs. \n Please refer to Page 5 of http://proceedings.mlr.press/v70/achiam17a/achiam17a.pdf for related discussions.")
if not np.isscalar(cost_limit):
    raise RuntimeError(
        "CPO does not support multiple costs. \n \
        Please refer to Page 5 of \
        http://proceedings.mlr.press/v70/achiam17a/achiam17a.pdf \
        for related discussions."
    )

# set seed and computing
seed_all(seed)
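The rewritten guard above fails fast when a vector of cost limits is passed, since CPO handles only a single constraint. A minimal standalone sketch of the same check (the `check_cost_limit` helper is made up for illustration and is not part of FSRL):

```python
import numpy as np

def check_cost_limit(cost_limit):
    """Reject vector-valued limits, mirroring the CPO agent's guard above."""
    if not np.isscalar(cost_limit):
        raise RuntimeError(
            "CPO does not support multiple costs. Please refer to Page 5 of "
            "http://proceedings.mlr.press/v70/achiam17a/achiam17a.pdf "
            "for related discussions."
        )
    return 1  # CPO always uses a single cost dimension

check_cost_limit(25.0)            # OK: scalar limit, cost_dim == 1
# check_cost_limit([25.0, 50.0])  # would raise RuntimeError
```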
6 changes: 3 additions & 3 deletions fsrl/agent/cvpo_agent.py
@@ -1,9 +1,9 @@
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb
from torch.distributions import Independent, Normal
@@ -118,7 +118,7 @@ def __init__(

self.logger = logger
self.cost_limit = cost_limit

if np.isscalar(cost_limit):
    cost_dim = 1
else:
@@ -150,7 +150,7 @@ def __init__(

critics = []

for _ in range(1+cost_dim):
for _ in range(1 + cost_dim):
if double_critic:
net1 = Net(
state_shape,
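The `range(1 + cost_dim)` loop above builds one reward critic plus one critic per cost signal. A small sketch of how the critic count follows from `cost_limit` (the `num_critics` helper is illustrative only):

```python
import numpy as np

def num_critics(cost_limit):
    # Scalar limit -> one cost signal; a list/array -> one per entry.
    cost_dim = 1 if np.isscalar(cost_limit) else len(cost_limit)
    # One reward critic plus one critic per cost, as in the loop above.
    return 1 + cost_dim

print(num_critics(25.0))          # 2
print(num_critics([25.0, 50.0]))  # 3
```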
4 changes: 2 additions & 2 deletions fsrl/agent/ddpg_lag_agent.py
@@ -1,9 +1,9 @@
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@@ -112,7 +112,7 @@ def __init__(
    cost_dim = 1
else:
    cost_dim = len(cost_limit)

nets = [
Net(
state_shape,
2 changes: 1 addition & 1 deletion fsrl/agent/sac_lag_agent.py
@@ -133,7 +133,7 @@ def __init__(
unbounded=unbounded
).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)

critics = []
for _ in range(1 + cost_dim):
net1 = Net(
1 change: 0 additions & 1 deletion fsrl/agent/trpo_lag_agent.py
@@ -116,7 +116,6 @@ def __init__(
else:
    cost_dim = len(cost_limit)


# set seed and computing
seed_all(seed)
torch.set_num_threads(thread)
2 changes: 1 addition & 1 deletion fsrl/data/fast_collector.py
@@ -76,7 +76,7 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
else:  # ReplayBuffer or PrioritizedReplayBuffer
    assert buffer.maxsize > 0
    if self.env_num > 1:
        if type(buffer) == ReplayBuffer:
        if isinstance(buffer, ReplayBuffer):
            buffer_type = "ReplayBuffer"
            vector_type = "VectorReplayBuffer"
        else:
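The buffer check above switches from an exact type comparison to `isinstance`, which also avoids the common lint warning about direct type comparisons (e.g. flake8 E721). For reference, a small sketch of how the two checks differ on Tianshou buffers, where `PrioritizedReplayBuffer` subclasses `ReplayBuffer`:

```python
from tianshou.data import PrioritizedReplayBuffer, ReplayBuffer

buf = PrioritizedReplayBuffer(size=1000, alpha=0.6, beta=0.4)

print(type(buf) == ReplayBuffer)      # False: exact-type check ignores subclasses
print(isinstance(buf, ReplayBuffer))  # True: any ReplayBuffer subclass matches
```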
10 changes: 6 additions & 4 deletions fsrl/policy/cvpo.py
Expand Up @@ -124,10 +124,12 @@ def __init__(
self.dtype = next(self.actor.parameters()).dtype
self.cost_limit = [cost_limit] * (self.critics_num -
1) if np.isscalar(cost_limit) else cost_limit

self.max_episode_steps = max_episode_steps
# qc threshold in the E-step
self.qc_thres = [
c * (1 - self._gamma**max_episode_steps) / (1 - self._gamma) /
max_episode_steps for c in self.cost_limit
c * (1 - self._gamma**self.max_episode_steps) / (1 - self._gamma) /
self.max_episode_steps for c in self.cost_limit
]

# E-step init
@@ -169,8 +171,8 @@ def update_cost_limit(self, cost_limit: float):
1) if np.isscalar(cost_limit) else cost_limit

self.qc_thres = [
c * (1 - self._gamma**max_episode_steps) / (1 - self._gamma) /
max_episode_steps for c in self.cost_limit
c * (1 - self._gamma**self.max_episode_steps) / (1 - self._gamma) /
self.max_episode_steps for c in self.cost_limit
]

def pre_update_fn(self, **kwarg: Any) -> Any:
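The `qc_thres` expression above converts an undiscounted episode cost limit into a per-state discounted cost-value target for the E-step, scaling the limit by the average discount weight over an episode. A quick numeric check with made-up values:

```python
# Illustrative numbers only; not taken from the repository's configs.
gamma, max_episode_steps, cost_limit = 0.99, 300, 25.0

# Sum of discounts over an episode, averaged over its length, times the limit.
qc_thres = cost_limit * (1 - gamma**max_episode_steps) / (1 - gamma) / max_episode_steps
print(round(qc_thres, 2))  # ~7.92: the discounted cost-value threshold
```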
