Commit: Bug fixes for multi-cost safe RL (#5)

* Bug fixed for multi-cost cvpo
* Update cvpo.py
* Update cpo_agent.py
* Update ddpg_lag_agent.py
* Update ppo_lag_agent.py
* Update ppo_lag_agent.py
* Update sac_lag_agent.py
* Update sac_lag_agent.py
* Update trpo_lag_agent.py
* Update cpo_agent.py
* Update cpo_agent.py
* Update train_cvpo_agent.py
* Update train_ddpgl_agent.py
* Update train_focops_agent.py
* Update train_ppol_agent.py
* Update train_sacl_agent.py
* Update train_trpol_agent.py
* Update cpo_cfg.py
* Update cvpo_cfg.py
* Update ddpgl_cfg.py
* Update focosp_cfg.py
* Update ppol_cfg.py
* Update sacl_cfg.py
* Update trpol_cfg.py
* Update cvpo_agent.py
* Update ddpg_lag_agent.py

yihangyao authored Oct 3, 2023
1 parent 8ecb0e3 commit 584c65f
Showing 20 changed files with 58 additions and 19 deletions.
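
Taken together, the diffs below make one recurring change: `cost_limit` may now be either a scalar (a single constraint) or a sequence (one limit per cost signal), and the agents size their critic ensembles to match. A minimal sketch of the pattern, with illustrative values that do not appear in the diff:

```python
import numpy as np

cost_limit = [20.0, 40.0]  # illustrative; a plain float is still accepted
cost_dim = 1 if np.isscalar(cost_limit) else len(cost_limit)

# One reward critic plus one critic per cost dimension -- the new loop bound
# that replaces the hard-coded `range(2)` in the agent files below.
num_critics = 1 + cost_dim  # -> 3 for the two-cost example above
```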
examples/mlp/train_cvpo_agent.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -78,7 +78,7 @@ def train(args: TrainCfg):
     if args.name is None:
         args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
     if args.group is None:
-        args.group = args.task + "-cost-" + str(int(args.cost_limit))
+        args.group = args.task + "-cost-" + str(args.cost_limit)
     if args.logdir is not None:
         args.logdir = os.path.join(args.logdir, args.project, args.group)
     logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
```
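
The same one-line change recurs in the five remaining training scripts. The reason for dropping the `int()` cast: a multi-cost run passes a list of limits, which `int()` rejects, while `str()` handles both forms when building the logger group name. A quick illustration (values hypothetical):

```python
str(25)            # '25'            -- the old scalar behaviour
str([20.0, 40.0])  # '[20.0, 40.0]'  -- now usable as a group-name suffix
# str(int([20.0, 40.0])) raises TypeError: int() cannot convert a list
```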
examples/mlp/train_ddpgl_agent.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -77,7 +77,7 @@ def train(args: TrainCfg):
     if args.name is None:
         args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
     if args.group is None:
-        args.group = args.task + "-cost-" + str(int(args.cost_limit))
+        args.group = args.task + "-cost-" + str(args.cost_limit)
     if args.logdir is not None:
         args.logdir = os.path.join(args.logdir, args.project, args.group)
     logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
```
examples/mlp/train_focops_agent.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -77,7 +77,7 @@ def train(args: TrainCfg):
     if args.name is None:
         args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
     if args.group is None:
-        args.group = args.task + "-cost-" + str(int(args.cost_limit))
+        args.group = args.task + "-cost-" + str(args.cost_limit)
     if args.logdir is not None:
         args.logdir = os.path.join(args.logdir, args.project, args.group)
     logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
```
examples/mlp/train_ppol_agent.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -77,7 +77,7 @@ def train(args: TrainCfg):
     if args.name is None:
         args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
     if args.group is None:
-        args.group = args.task + "-cost-" + str(int(args.cost_limit))
+        args.group = args.task + "-cost-" + str(args.cost_limit)
     if args.logdir is not None:
         args.logdir = os.path.join(args.logdir, args.project, args.group)
     logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
```
examples/mlp/train_sacl_agent.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -77,7 +77,7 @@ def train(args: TrainCfg):
     if args.name is None:
         args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
     if args.group is None:
-        args.group = args.task + "-cost-" + str(int(args.cost_limit))
+        args.group = args.task + "-cost-" + str(args.cost_limit)
     if args.logdir is not None:
         args.logdir = os.path.join(args.logdir, args.project, args.group)
     logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
```
examples/mlp/train_trpol_agent.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -77,7 +77,7 @@ def train(args: TrainCfg):
     if args.name is None:
         args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
     if args.group is None:
-        args.group = args.task + "-cost-" + str(int(args.cost_limit))
+        args.group = args.task + "-cost-" + str(args.cost_limit)
     if args.logdir is not None:
         args.logdir = os.path.join(args.logdir, args.project, args.group)
     logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
```
fsrl/agent/cpo_agent.py (5 changes: 5 additions & 0 deletions)

```diff
@@ -101,6 +101,11 @@ def __init__(
         self.logger = logger
         self.cost_limit = cost_limit
 
+        if np.isscalar(cost_limit):
+            cost_dim = 1
+        else:
+            raise RuntimeError("CPO does not support multiple costs. \n Please refer to Page 5 of http://proceedings.mlr.press/v70/achiam17a/achiam17a.pdf for related discussions.")
+
         # set seed and computing
         seed_all(seed)
         torch.set_num_threads(thread)
```
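
Unlike the Lagrangian agents below, CPO gets a hard single-constraint guard rather than multi-cost support: the new error message points to page 5 of the CPO paper for why the algorithm is stated for a single constraint. The guard hinges on `np.isscalar`, the same check the other agents use to branch; a quick look at its behaviour:

```python
import numpy as np

np.isscalar(25.0)          # True  -> accepted, cost_dim = 1
np.isscalar([25.0, 50.0])  # False -> the new guard raises RuntimeError
```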
fsrl/agent/cvpo_agent.py (8 changes: 7 additions & 1 deletion)

```diff
@@ -3,6 +3,7 @@
 import gymnasium as gym
 import torch
 import torch.nn as nn
+import numpy as np
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb
 from torch.distributions import Independent, Normal
@@ -117,6 +118,11 @@ def __init__(
 
         self.logger = logger
         self.cost_limit = cost_limit
 
+        if np.isscalar(cost_limit):
+            cost_dim = 1
+        else:
+            cost_dim = len(cost_limit)
+
         # set seed and computing
         seed_all(seed)
@@ -144,7 +150,7 @@ def __init__(
 
         critics = []
 
-        for _ in range(2):
+        for _ in range(1+cost_dim):
             if double_critic:
                 net1 = Net(
                     state_shape,
```
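
A hypothetical usage sketch of the fixed constructor (the agent and argument names come from this repository; the stand-in environment and limit values are assumptions, included only to exercise construction):

```python
import gymnasium as gym
from fsrl.agent import CVPOAgent

env = gym.make("Pendulum-v1")  # stand-in continuous-control task, not a safety env

agent_single = CVPOAgent(env, cost_limit=10)       # critics: 1 reward + 1 cost
agent_multi = CVPOAgent(env, cost_limit=[10, 20])  # critics: 1 reward + 2 cost
```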
fsrl/agent/ddpg_lag_agent.py (9 changes: 8 additions & 1 deletion)

```diff
@@ -3,6 +3,7 @@
 import gymnasium as gym
 import torch
 import torch.nn as nn
+import numpy as np
 from tianshou.exploration import GaussianNoise
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import Actor, Critic
@@ -106,14 +107,20 @@ def __init__(
         net = Net(state_shape, hidden_sizes=hidden_sizes, device=device)
         actor = Actor(net, action_shape, max_action=max_action, device=device).to(device)
         actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)
+
+        if np.isscalar(cost_limit):
+            cost_dim = 1
+        else:
+            cost_dim = len(cost_limit)
+
         nets = [
             Net(
                 state_shape,
                 action_shape,
                 hidden_sizes=hidden_sizes,
                 concat=True,
                 device=device
-            ) for i in range(2)
+            ) for i in range(cost_dim + 1)
         ]
         critic = [Critic(n, device=device).to(device) for n in nets]
         critic_optim = torch.optim.Adam(nn.ModuleList(critic).parameters(), lr=critic_lr)
```
fsrl/agent/ppo_lag_agent.py (7 changes: 6 additions & 1 deletion)

```diff
@@ -119,6 +119,11 @@ def __init__(
         self.logger = logger
         self.cost_limit = cost_limit
 
+        if np.isscalar(cost_limit):
+            cost_dim = 1
+        else:
+            cost_dim = len(cost_limit)
+
         # set seed and computing
         seed_all(seed)
         torch.set_num_threads(thread)
@@ -136,7 +141,7 @@ def __init__(
             Critic(
                 Net(state_shape, hidden_sizes=hidden_sizes, device=device),
                 device=device
-            ).to(device) for _ in range(2)
+            ).to(device) for _ in range(1 + cost_dim)
         ]
 
         torch.nn.init.constant_(actor.sigma_param, -0.5)
```
fsrl/agent/sac_lag_agent.py (9 changes: 7 additions & 2 deletions)

```diff
@@ -109,6 +109,11 @@ def __init__(
         self.logger = logger
         self.cost_limit = cost_limit
 
+        if np.isscalar(cost_limit):
+            cost_dim = 1
+        else:
+            cost_dim = len(cost_limit)
+
         # set seed and computing
         seed_all(seed)
         torch.set_num_threads(thread)
@@ -128,9 +133,9 @@ def __init__(
             unbounded=unbounded
         ).to(device)
         actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)
 
         critics = []
-        for _ in range(2):
+        for _ in range(1 + cost_dim):
             net1 = Net(
                 state_shape,
                 action_shape,
```
fsrl/agent/trpo_lag_agent.py (8 changes: 7 additions & 1 deletion)

```diff
@@ -111,6 +111,12 @@ def __init__(
         self.logger = logger
         self.cost_limit = cost_limit
 
+        if np.isscalar(cost_limit):
+            cost_dim = 1
+        else:
+            cost_dim = len(cost_limit)
+
+
         # set seed and computing
         seed_all(seed)
         torch.set_num_threads(thread)
@@ -128,7 +134,7 @@ def __init__(
             Critic(
                 Net(state_shape, hidden_sizes=hidden_sizes, device=device),
                 device=device
-            ).to(device) for _ in range(2)
+            ).to(device) for _ in range(1 + cost_dim)
         ]
 
         torch.nn.init.constant_(actor.sigma_param, -0.5)
```
fsrl/config/cpo_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -81,7 +81,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
     epoch: int = 250
     cost_limit: float = 25
     # collecting params
```
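
This and the six config diffs that follow make the same one-line change, pointing the Mujoco base config at the Gymnasium-suffixed task ID. One detail worth noting for multi-cost runs: the `cost_limit: float = 25` annotation is not enforced at runtime, so a list can still be assigned to the field. A minimal sketch with an abbreviated stand-in dataclass (not the real `TrainCfg`):

```python
from dataclasses import dataclass

@dataclass
class MujocoBaseCfg:  # abbreviated stand-in, not the real TrainCfg subclass
    task: str = "SafetyPointCircle1Gymnasium-v0"
    epoch: int = 250
    cost_limit: float = 25

# Python dataclasses neither coerce nor validate annotations, so a
# list-valued limit passes straight through to the agent constructor.
cfg = MujocoBaseCfg(cost_limit=[25.0, 25.0])
print(cfg.cost_limit)  # [25.0, 25.0]
```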
fsrl/config/cvpo_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -87,7 +87,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
    epoch: int = 250
     cost_limit: float = 25
     unbounded: bool = True
```
fsrl/config/ddpgl_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -75,7 +75,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
     epoch: int = 250
     cost_limit: float = 25
     gamma: float = 0.99
```
fsrl/config/focosp_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -85,7 +85,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
     epoch: int = 250
     cost_limit: float = 25
     # collecting params
```
fsrl/config/ppol_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -85,7 +85,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
     epoch: int = 250
     cost_limit: float = 25
     # collecting params
```
fsrl/config/sacl_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -80,7 +80,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
     epoch: int = 250
     cost_limit: float = 25
     gamma: float = 0.99
```
fsrl/config/trpol_cfg.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -83,7 +83,7 @@ class Bullet10MCfg(TrainCfg):
 
 @dataclass
 class MujocoBaseCfg(TrainCfg):
-    task: str = "SafetyPointCircle1-v0"
+    task: str = "SafetyPointCircle1Gymnasium-v0"
     epoch: int = 250
     cost_limit: float = 25
     # collecting params
```
fsrl/policy/cvpo.py (5 changes: 5 additions & 0 deletions)

```diff
@@ -168,6 +168,11 @@ def update_cost_limit(self, cost_limit: float):
         self.cost_limit = [cost_limit] * (self.critics_num -
                                           1) if np.isscalar(cost_limit) else cost_limit
 
+        self.qc_thres = [
+            c * (1 - self._gamma**max_episode_steps) / (1 - self._gamma) /
+            max_episode_steps for c in self.cost_limit
+        ]
+
     def pre_update_fn(self, **kwarg: Any) -> Any:
         """Init the mstep optimizer and dual variables."""
         self.mstep_dual_mu = torch.zeros(
```
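
The added `qc_thres` recomputation converts each episodic cost limit into a discounted cost-Q threshold: an episode that spends its budget `c` uniformly over `T = max_episode_steps` steps incurs `c / T` per step, and summing that under discount `gamma` over the horizon gives `c * (1 - gamma^T) / (1 - gamma) / T`. A standalone re-derivation (the function name and example numbers are illustrative, not from the diff):

```python
def qc_threshold(cost_limit: float, gamma: float, max_episode_steps: int) -> float:
    # Average per-step cost times the discounted-sum factor sum_{t<T} gamma**t.
    return cost_limit * (1 - gamma**max_episode_steps) / (1 - gamma) / max_episode_steps

print(qc_threshold(25, 0.99, 300))  # ~7.92: the Q_c target for an episodic limit of 25
```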
