feature(rjy): add HAPPO algorithm (#717)
* model(rjy): add vac model for HAPPO

* test(rjy): polish havac and add test

* polish(rjy): fix conflict

* polish(rjy): add hidden_state for ac

* feature(rjy): change the havac to multiagent model

* feature(rjy): add happo forward_learn

* feature(rjy): modify the happo_data

* test(rjy): add happo data test

* feature(rjy): add HAPPO policy

* feature(rjy): try to fit mu-mujoco

* polish(rjy): Change code to adapt to mujoco

* fix(rjy): fix the distribution in ppo update

* fix(rjy): fix the happo+mujoco

* config(rjy): add walker+happo config

* polish(rjy): separate actors and critics

* polish(rjy): polish according to comments

* polish(rjy): fix the pipeline

* polish(rjy): fix the style

* polish(rjy): polish according to comments

* polish(rjy): fix style

* polish(rjy): fix style

* polish(rjy): fix style

* polish(rjy): separate the happo model

* fix(rjy): fix happo model style

* polish(rjy): polish happo policy comments

* polish(rjy): polish happo comments
nighood authored Jan 11, 2024
1 parent 6994a67 commit 4738444
Showing 14 changed files with 2,122 additions and 6 deletions.
13 changes: 11 additions & 2 deletions ding/model/common/head.py
@@ -1102,8 +1102,9 @@ class ReparameterizationHead(nn.Module):
     Interfaces:
         ``__init__``, ``forward``.
     """
-    default_sigma_type = ['fixed', 'independent', 'conditioned']
+    # The 'happo' type aligns with the sigma initialization of the network in the original
+    # HAPPO paper; this branch should be refactored later.
+    default_sigma_type = ['fixed', 'independent', 'conditioned', 'happo']
     default_bound_type = ['tanh', None]
 
     def __init__(
@@ -1155,6 +1156,11 @@ def __init__(
             self.log_sigma_param = nn.Parameter(torch.zeros(1, output_size))
         elif self.sigma_type == 'conditioned':
             self.log_sigma_layer = nn.Linear(hidden_size, output_size)
+        elif self.sigma_type == 'happo':
+            self.sigma_x_coef = 1.
+            self.sigma_y_coef = 0.5
+            # These coefficients (sigma_x_coef, sigma_y_coef) follow the HAPPO paper: http://arxiv.org/abs/2109.11251
+            self.log_sigma_param = nn.Parameter(torch.ones(1, output_size) * self.sigma_x_coef)
 
     def forward(self, x: torch.Tensor) -> Dict:
         """
@@ -1190,6 +1196,9 @@ def forward(self, x: torch.Tensor) -> Dict:
         elif self.sigma_type == 'conditioned':
             log_sigma = self.log_sigma_layer(x)
             sigma = torch.exp(torch.clamp(log_sigma, -20, 2))
+        elif self.sigma_type == 'happo':
+            log_sigma = self.log_sigma_param + torch.zeros_like(mu)
+            sigma = torch.sigmoid(log_sigma / self.sigma_x_coef) * self.sigma_y_coef
         return {'mu': mu, 'sigma': sigma}
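
A quick sanity check of the new branch (a sketch, not part of the diff; the `hidden_size`/`output_size` constructor arguments are the head's usual ones, everything else is left at its defaults):

    import torch
    from ding.model.common.head import ReparameterizationHead

    # With sigma_type='happo', log-sigma is a learned parameter broadcast to
    # mu's shape and squashed by a sigmoid, so sigma is state-independent and
    # bounded in (0, sigma_y_coef) = (0, 0.5).
    head = ReparameterizationHead(hidden_size=64, output_size=6, sigma_type='happo')
    out = head(torch.randn(4, 64))
    assert out['mu'].shape == (4, 6) and out['sigma'].shape == (4, 6)
    assert (out['sigma'] > 0).all() and (out['sigma'] < 0.5).all()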
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -27,3 +27,4 @@
 from .bcq import BCQ
 from .edac import EDAC
 from .ebm import EBM, AutoregressiveEBM
+from .havac import HAVAC
500 changes: 500 additions & 0 deletions ding/model/template/havac.py

Large diffs are not rendered by default.
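
Although the havac.py diff is collapsed, the tests below pin down its interface; as a rough sketch (shapes and argument names taken from the tests, everything else assumed):

    import torch
    from ding.model.template import HAVAC

    # Heterogeneous-agent VAC: one actor-critic per agent, queried by agent
    # index with a time-major (T, B, ...) batch.
    T, B, obs_dim, global_obs_dim, action_dim, agent_num = 8, 3, 32, 128, 9, 5
    model = HAVAC(
        agent_obs_shape=obs_dim,
        global_obs_shape=global_obs_dim,
        action_shape=action_dim,
        agent_num=agent_num,
        use_lstm=True,
    )
    data = {
        'obs': {
            'agent_state': torch.randn(T, B, obs_dim),
            'global_state': torch.randn(T, B, global_obs_dim),
            'action_mask': torch.randint(0, 2, size=(T, B, action_dim)),
        },
        'actor_prev_state': [None for _ in range(B)],  # one LSTM state per sample
    }
    output = model(0, data, mode='compute_actor')  # query agent 0
    assert output['logit'].shape == (T, B, action_dim)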

103 changes: 103 additions & 0 deletions ding/model/template/tests/test_havac.py
@@ -0,0 +1,103 @@
import pytest
import torch
import random
from ding.torch_utils import is_differentiable
from ding.model.template import HAVAC


@pytest.mark.unittest
class TestHAVAC:

    def test_havac_rnn_actor(self):
        # discrete+rnn
        bs, T = 3, 8
        obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
        agent_num = 5
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, obs_dim),
                'global_state': torch.randn(T, bs, global_obs_dim),
                'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
            },
            'actor_prev_state': [None for _ in range(bs)],
        }
        model = HAVAC(
            agent_obs_shape=obs_dim,
            global_obs_shape=global_obs_dim,
            action_shape=action_dim,
            agent_num=agent_num,
            use_lstm=True,
        )
        agent_idx = random.randint(0, agent_num - 1)
        output = model(agent_idx, data, mode='compute_actor')
        assert set(output.keys()) == set(['logit', 'actor_next_state', 'actor_hidden_state'])
        assert output['logit'].shape == (T, bs, action_dim)
        assert len(output['actor_next_state']) == bs
        print(output['actor_next_state'][0]['h'].shape)
        loss = output['logit'].sum()
        is_differentiable(loss, model.agent_models[agent_idx].actor)

    def test_havac_rnn_critic(self):
        # discrete+rnn
        bs, T = 3, 8
        obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
        agent_num = 5
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, obs_dim),
                'global_state': torch.randn(T, bs, global_obs_dim),
                'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
            },
            'critic_prev_state': [None for _ in range(bs)],
        }
        model = HAVAC(
            agent_obs_shape=obs_dim,
            global_obs_shape=global_obs_dim,
            action_shape=action_dim,
            agent_num=agent_num,
            use_lstm=True,
        )
        agent_idx = random.randint(0, agent_num - 1)
        output = model(agent_idx, data, mode='compute_critic')
        assert set(output.keys()) == set(['value', 'critic_next_state', 'critic_hidden_state'])
        assert output['value'].shape == (T, bs)
        assert len(output['critic_next_state']) == bs
        print(output['critic_next_state'][0]['h'].shape)
        loss = output['value'].sum()
        is_differentiable(loss, model.agent_models[agent_idx].critic)

    def test_havac_rnn_actor_critic(self):
        # discrete+rnn
        bs, T = 3, 8
        obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
        agent_num = 5
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, obs_dim),
                'global_state': torch.randn(T, bs, global_obs_dim),
                'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
            },
            'actor_prev_state': [None for _ in range(bs)],
            'critic_prev_state': [None for _ in range(bs)],
        }
        model = HAVAC(
            agent_obs_shape=obs_dim,
            global_obs_shape=global_obs_dim,
            action_shape=action_dim,
            agent_num=agent_num,
            use_lstm=True,
        )
        agent_idx = random.randint(0, agent_num - 1)
        output = model(agent_idx, data, mode='compute_actor_critic')
        assert set(output.keys()) == set(
            ['logit', 'actor_next_state', 'actor_hidden_state', 'value', 'critic_next_state', 'critic_hidden_state']
        )
        assert output['logit'].shape == (T, bs, action_dim)
        assert output['value'].shape == (T, bs)
        loss = output['logit'].sum() + output['value'].sum()
        is_differentiable(loss, model.agent_models[agent_idx])


# test_havac_rnn_actor()
# test_havac_rnn_critic()
# test_havac_rnn_actor_critic()
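
These cases carry the repository's `unittest` pytest marker, so they can be run directly with, e.g.:

    pytest -m unittest ding/model/template/tests/test_havac.py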
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -55,3 +55,4 @@
 # new-type policy
 from .ppof import PPOFPolicy
 from .prompt_pg import PromptPGPolicy
+from .happo import HAPPOPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -51,6 +51,7 @@
 from .edac import EDACPolicy
 from .prompt_pg import PromptPGPolicy
 from .plan_diffuser import PDPolicy
+from .happo import HAPPOPolicy
 
 
 class EpsCommandModePolicy(CommandModePolicy):
@@ -186,6 +187,11 @@ class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
     pass
 
 
+@POLICY_REGISTRY.register('happo_command')
+class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
+    pass
 
 
 @POLICY_REGISTRY.register('ppo_stdim_command')
 class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
     pass
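
The 'happo_command' registration hooks HAPPO into the command-mode serial pipeline. A hypothetical minimal create-config fragment (the walker+happo config added in this commit is the real reference; names here are illustrative):

    create_config = dict(
        env_manager=dict(type='base'),
        # The policy is referenced by its base name; the command-mode wrapper
        # above is what the registry stores under 'happo_command'.
        policy=dict(type='happo'),
    )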