Hi @Toni-SM , thanks for the great library! I'm trying to train an agent with SAC_RNN, but I've been getting extremely poor performance compared to SAC with a standard MLP, and I'm not sure why. In theory, my problem should only benefit from using an RNN since it's partially observable. In particular, I've had three main issues:
Lack of learning – no improvement in success rate or rewards (comparison plots below):
RNN (these plots are from an earlier timestep, but the behavior persists for the entire duration):
MLP:
Extremely slow training – I typically get about 5-10 it/s with the standard MLP, but it drops to ~1 it/s with the RNN even if I reduce the number of environments. Is this to be expected?
Possible bug with random timesteps – if I enable random timesteps with the RNN setup, I get the following error once the random timesteps are over and learning starts (the config change that triggers it is shown after the traceback):
Traceback (most recent call last):
File "/isaac-sim/kit/python/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/isaac-sim/kit/python/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/isaac-sim/kit/python/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/isaac-sim/kit/python/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/isaac-sim/kit/python/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/isaaclab/source/isaaclab_tasks/isaaclab_tasks/utils/hydra.py", line 101, in hydra_main
func(env_cfg, agent_cfg, *args, **kwargs)
File "/blind_manipulation/scripts/skrl/orchestrator_sac.py", line 123, in main
trainer.train(env_cfg)
File "/blind_manipulation/scripts/skrl/trainer_class_sac.py", line 196, in train
trainer.train()
File "/isaac-sim/kit/python/lib/python3.10/site-packages/skrl/trainers/torch/sequential.py", line 86, in train
self.single_agent_train()
File "/isaac-sim/kit/python/lib/python3.10/site-packages/skrl/trainers/torch/base.py", line 193, in single_agent_train
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
File "/isaac-sim/kit/python/lib/python3.10/site-packages/skrl/agents/torch/sac/sac_rnn.py", line 274, in act
actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy")
File "/isaac-sim/kit/python/lib/python3.10/site-packages/skrl/models/torch/gaussian.py", line 129, in act
mean_actions, log_std, outputs = self.compute(inputs, role)
File "/blind_manipulation/scripts/skrl/sac_rnn.py", line 50, in compute
hidden_states = inputs["rnn"][0]
IndexError: list index out of range
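For reference, the only configuration change needed to hit this is enabling the random timesteps in the agent config shown further down (the value here is just a placeholder):

cfg["random_timesteps"] = 1000  # placeholder value; anything > 0 reproduces the IndexError above once the random phase ends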
Pretty much everything (observations, SAC hyperparameters, etc.) is kept constant between training the MLP policy and the RNN policy; the only difference is the policy implementation itself in skrl. My RNN implementation is below:
import torch
import torch.nn as nn

from skrl.models.torch import Model, GaussianMixin, DeterministicMixin


class StochasticActor(GaussianMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
                 num_envs=128, num_layers=1, hidden_size=128, sequence_length=16):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

        self.num_envs = num_envs
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.GRU(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # expects (N, L, Hin) inputs
        self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, self.num_actions),
                                 nn.Tanh())
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def get_specification(self):
        # hidden state size per GRU: (D * num_layers, N, Hout)
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        # treat truncation as termination for resetting the hidden states
        truncated = inputs.get("truncated", None)
        terminated = terminated | truncated if terminated is not None and truncated is not None else terminated
        hidden_states = inputs["rnn"][0]

        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin)
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length,
                                               hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            # get the hidden states corresponding to the start of each sequence
            hidden_states = hidden_states[:, :, 0, :].contiguous()  # (D * num_layers, N, Hout)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, hidden_states = self.rnn(rnn_input[:, i0:i1, :], hidden_states)
                    hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                    rnn_outputs.append(rnn_output)
                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
        # rollout / evaluation: a single step per environment
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, 1, Hin)
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output: (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)
        return self.net(rnn_output), self.log_std_parameter, {"rnn": [hidden_states]}
class Critic(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=128, num_layers=1, hidden_size=128, sequence_length=16):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.GRU(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)
        self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 1))

    def get_specification(self):
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        truncated = inputs.get("truncated", None)
        terminated = terminated | truncated if terminated is not None and truncated is not None else terminated
        hidden_states = inputs["rnn"][0]

        # critic is only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role in ["target_critic_1", "target_critic_2"] else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.rnn(rnn_input[:, i0:i1, :], hidden_states)
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)
            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)
        return self.net(rnn_output), {"rnn": [hidden_states]}
from skrl.agents.torch.sac import SAC_DEFAULT_CONFIG
from skrl.resources.schedulers.torch import KLAdaptiveLR

cfg = SAC_DEFAULT_CONFIG.copy()
cfg["gradient_steps"] = 10
cfg["batch_size"] = 64
cfg["discount_factor"] = 0.99
cfg["polyak"] = 0.05
cfg["actor_learning_rate"] = 0.001
cfg["critic_learning_rate"] = 0.001
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 7000
cfg["grad_norm_clip"] = 0
cfg["learn_entropy"] = True
cfg["entropy_learning_rate"] = 3e-4
cfg["initial_entropy_value"] = 0.005
cfg["seed"] = 500
cfg["experiment"]["directory"] = "collision_free_nav_ur10"
cfg["experiment"]["experiment_name"] = ""
cfg["experiment"]["write_interval"] = "auto"
cfg["experiment"]["checkpoint_interval"] = "auto"
cfg["experiment"]["wandb"] = True
cfg["experiment"]["wandb_kwargs"]["project"] = "collision-free-nav-ur10"
cfg["experiment"]["wandb_kwargs"]["entity"] = "mukundkk-cmu"
cfg["experiment"]["wandb_kwargs"]["name"] = "debug-run"
cfg["experiment"]["wandb_kwargs"]["resume"] = "must"
cfg_trainer = {"timesteps": 500000, "close_environment_at_exit": False, "environment_info": "log"}
And the agent instantiation in a separate file for reference:
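Roughly, the wiring looks like this (a simplified sketch: env, device, and the memory size are placeholders, and my actual script wraps the trainer in a custom class):

# simplified sketch of the agent/trainer wiring (env, device, and memory_size are placeholders)
from skrl.memories.torch import RandomMemory
from skrl.agents.torch.sac import SAC_RNN
from skrl.trainers.torch import SequentialTrainer

device = env.device  # skrl-wrapped Isaac Lab environment

models = {
    "policy": StochasticActor(env.observation_space, env.action_space, device, num_envs=env.num_envs),
    "critic_1": Critic(env.observation_space, env.action_space, device, num_envs=env.num_envs),
    "critic_2": Critic(env.observation_space, env.action_space, device, num_envs=env.num_envs),
    "target_critic_1": Critic(env.observation_space, env.action_space, device, num_envs=env.num_envs),
    "target_critic_2": Critic(env.observation_space, env.action_space, device, num_envs=env.num_envs),
}

memory = RandomMemory(memory_size=100_000, num_envs=env.num_envs, device=device)  # memory_size is a placeholder

agent = SAC_RNN(models=models,
                memory=memory,
                cfg=cfg,
                observation_space=env.observation_space,
                action_space=env.action_space,
                device=device)

trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)
trainer.train()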
Sorry for the super long post 😅 I'd really appreciate any help/advice on how to debug this (or if there are issues with my code as is)! Please let me know if you need any other information.