diff --git a/README.md b/README.md
index 13d4398..da216d7 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,8 @@
[![License](https://img.shields.io/badge/License-MIT-yellow.svg)](#license)
[![CodeCov](https://codecov.io/github/liuzuxin/fsrl/branch/main/graph/badge.svg?token=BU27LTW9F3)](https://codecov.io/github/liuzuxin/fsrl)
[![Tests](https://github.com/liuzuxin/fsrl/actions/workflows/test.yml/badge.svg)](https://github.com/liuzuxin/fsrl/actions/workflows/test.yml)
+[![GitHub Repo Stars](https://img.shields.io/github/stars/liuzuxin/fsrl?color=brightgreen&logo=github)](https://github.com/liuzuxin/fsrl/stargazers)
+[![Downloads](https://static.pepy.tech/personalized-badge/fast-safe-rl?period=total&left_color=grey&right_color=blue&left_text=downloads)](https://pepy.tech/project/fast-safe-rl)
@@ -33,9 +35,41 @@
The Fast Safe Reinforcement Learning (FSRL) package provides modularized implementations
of Safe RL algorithms based on PyTorch and the [Tianshou](https://tianshou.readthedocs.io/en/master/) framework. Safe RL is a rapidly evolving subfield of RL, focusing on ensuring the safety of learning agents during the training and deployment process. The study of Safe RL is essential because it addresses the critical challenge of preventing unintended or harmful actions while still optimizing an agent's performance in complex environments.
-This project offers high-quality and fast implementations of popular Safe RL algorithms, serving as an ideal starting point for those looking to explore and experiment in this field. By providing a comprehensive and accessible toolkit, the FSRL package aims to accelerate research in this crucial area and contribute to the development of safer and more reliable RL-powered systems.
+This project offers high-quality and fast implementations of popular Safe RL algorithms, serving as an ideal starting point for those looking to explore and experiment in this field. By providing a comprehensive and accessible toolkit, the FSRL package aims to accelerate research in this crucial area and contribute to the development of safer and more reliable RL-powered systems. Your feedback and contributions are highly appreciated, as they help us improve the FSRL package.
-**Please note that this project is still under active development, and major updates might be expected.** Your feedback and contributions are highly appreciated, as they help us improve the FSRL package.
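+
+As a minimal sketch of the intended workflow (the task name, logger setup, and `learn()` arguments below are illustrative; see the documentation for the exact API), training a PPO-Lagrangian agent looks roughly like this:
+
+```python
+import bullet_safety_gym  # noqa: F401 -- registers the SafetyCarCircle-v0 task
+import gymnasium as gym
+from tianshou.env import DummyVectorEnv
+
+from fsrl.agent import PPOLagAgent
+from fsrl.utils import TensorboardLogger
+
+task = "SafetyCarCircle-v0"
+# NOTE: argument names are indicative; check the API docs for exact signatures
+logger = TensorboardLogger("logs", log_txt=True, name=task)
+
+# the agent builds its actor and critic networks from the environment's spaces
+agent = PPOLagAgent(gym.make(task), logger)
+
+# vectorized environments for training and evaluation
+train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(10)])
+test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(2)])
+
+agent.learn(train_envs, test_envs, epoch=100)
+```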
+If you find this code useful, please cite:
+```bibtex
+@article{liu2023datasets,
+ title={Datasets and Benchmarks for Offline Safe Reinforcement Learning},
+ author={Liu, Zuxin and Guo, Zijian and Lin, Haohong and Yao, Yihang and Zhu, Jiacheng and Cen, Zhepeng and Hu, Hanjiang and Yu, Wenhao and Zhang, Tingnan and Tan, Jie and others},
+ journal={arXiv preprint arXiv:2306.09303},
+ year={2023}
+}
+```
## 🌟 Key Features
diff --git a/docs/tutorials/benchmark.rst b/docs/tutorials/benchmark.rst
index a73e116..b3e385a 100644
--- a/docs/tutorials/benchmark.rst
+++ b/docs/tutorials/benchmark.rst
@@ -90,7 +90,7 @@ Safety-Gymnasium-Navigation-Tasks
-
+
@@ -111,4 +111,3 @@ Safety-Gymnasium-Navigation-Tasks
-
diff --git a/fsrl/agent/cpo_agent.py b/fsrl/agent/cpo_agent.py
index f35e83f..1c9a4d9 100644
--- a/fsrl/agent/cpo_agent.py
+++ b/fsrl/agent/cpo_agent.py
@@ -101,10 +101,13 @@ def __init__(
self.logger = logger
self.cost_limit = cost_limit
- if np.isscalar(cost_limit):
- cost_dim = 1
- else:
- raise RuntimeError("CPO does not support multiple costs. \n Please refer to Page 5 of http://proceedings.mlr.press/v70/achiam17a/achiam17a.pdf for related discussions.")
+        if not np.isscalar(cost_limit):
+            raise RuntimeError(
+                "CPO does not support multiple costs.\n"
+                "Please refer to Page 5 of "
+                "http://proceedings.mlr.press/v70/achiam17a/achiam17a.pdf "
+                "for related discussions."
+            )
# set seed and computing
seed_all(seed)
diff --git a/fsrl/agent/cvpo_agent.py b/fsrl/agent/cvpo_agent.py
index b7185bf..3050eb3 100644
--- a/fsrl/agent/cvpo_agent.py
+++ b/fsrl/agent/cvpo_agent.py
@@ -1,9 +1,9 @@
from typing import Optional, Tuple
import gymnasium as gym
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb
from torch.distributions import Independent, Normal
@@ -118,7 +118,7 @@ def __init__(
self.logger = logger
self.cost_limit = cost_limit
-
+
if np.isscalar(cost_limit):
cost_dim = 1
else:
@@ -150,7 +150,7 @@ def __init__(
critics = []
- for _ in range(1+cost_dim):
+ for _ in range(1 + cost_dim):
if double_critic:
net1 = Net(
state_shape,
diff --git a/fsrl/agent/ddpg_lag_agent.py b/fsrl/agent/ddpg_lag_agent.py
index 02a2af4..c457f55 100644
--- a/fsrl/agent/ddpg_lag_agent.py
+++ b/fsrl/agent/ddpg_lag_agent.py
@@ -1,9 +1,9 @@
from typing import Optional, Tuple
import gymnasium as gym
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@@ -112,7 +112,7 @@ def __init__(
cost_dim = 1
else:
cost_dim = len(cost_limit)
-
+
nets = [
Net(
state_shape,
diff --git a/fsrl/agent/sac_lag_agent.py b/fsrl/agent/sac_lag_agent.py
index 3ab51ca..375f9f1 100644
--- a/fsrl/agent/sac_lag_agent.py
+++ b/fsrl/agent/sac_lag_agent.py
@@ -133,7 +133,7 @@ def __init__(
unbounded=unbounded
).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)
-
+
critics = []
for _ in range(1 + cost_dim):
net1 = Net(
diff --git a/fsrl/agent/trpo_lag_agent.py b/fsrl/agent/trpo_lag_agent.py
index d617320..7f12741 100644
--- a/fsrl/agent/trpo_lag_agent.py
+++ b/fsrl/agent/trpo_lag_agent.py
@@ -116,7 +116,6 @@ def __init__(
else:
cost_dim = len(cost_limit)
-
# set seed and computing
seed_all(seed)
torch.set_num_threads(thread)
diff --git a/fsrl/data/fast_collector.py b/fsrl/data/fast_collector.py
index 503ccf3..7edfa4f 100644
--- a/fsrl/data/fast_collector.py
+++ b/fsrl/data/fast_collector.py
@@ -76,7 +76,7 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
else: # ReplayBuffer or PrioritizedReplayBuffer
assert buffer.maxsize > 0
if self.env_num > 1:
- if type(buffer) == ReplayBuffer:
+                if isinstance(buffer, ReplayBuffer):
buffer_type = "ReplayBuffer"
vector_type = "VectorReplayBuffer"
else:
diff --git a/fsrl/policy/cvpo.py b/fsrl/policy/cvpo.py
index 578686c..3db1300 100644
--- a/fsrl/policy/cvpo.py
+++ b/fsrl/policy/cvpo.py
@@ -124,10 +124,14 @@
self.dtype = next(self.actor.parameters()).dtype
self.cost_limit = [cost_limit] * (self.critics_num -
1) if np.isscalar(cost_limit) else cost_limit
+
+ self.max_episode_steps = max_episode_steps
# qc threshold in the E-step
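+        # An episodic cost limit c is converted into a threshold on the discounted
+        # cost return: (c / T) * sum_{t=0}^{T-1} gamma^t = c * (1 - gamma^T) / ((1 - gamma) * T)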
self.qc_thres = [
- c * (1 - self._gamma**max_episode_steps) / (1 - self._gamma) /
- max_episode_steps for c in self.cost_limit
+ c * (1 - self._gamma**self.max_episode_steps) / (1 - self._gamma) /
+ self.max_episode_steps for c in self.cost_limit
]
# E-step init
@@ -169,8 +173,8 @@ def update_cost_limit(self, cost_limit: float):
1) if np.isscalar(cost_limit) else cost_limit
self.qc_thres = [
- c * (1 - self._gamma**max_episode_steps) / (1 - self._gamma) /
- max_episode_steps for c in self.cost_limit
+ c * (1 - self._gamma**self.max_episode_steps) / (1 - self._gamma) /
+ self.max_episode_steps for c in self.cost_limit
]
def pre_update_fn(self, **kwarg: Any) -> Any: