-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(rjy): add HAPPO algorithm (#717)
* model(rjy): add vac model for HAPPO * test(rjy): polish havac and add test * polish(rjy): fix conflict * polish(rjy): add hidden_state for ac * feature(rjy): change the havac to multiagent model * feature(rjy): add happo forward_learn * feature(rjy): modify the happo_data * test(rjy): add happo data test * feature(rjy): add HAPPO policy * feature(rjy): try to fit mu-mujoco * polish(rjy): Change code to adapt to mujoco * fix(rjy): fix the distribution in ppo update * fix(rjy): fix the happo+mujoco * config(rjy): add walker+happo config * polish(rjy): separate actors and critics * polish(rjy): polish according to comments * polish(rjy): fix the pipeline * polish(rjy): fix the style * polish(rjy): polish according to comments * polish(rjy): fix style * polish(rjy): fix style * polish(rjy): fix style * polish(rjy): seperate the happo model * fix(rjy): fix happo model style * polish(rjy): polish happo policy comments * polish(rjy): polish happo comments
- Loading branch information
Showing
14 changed files
with
2,122 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,3 +27,4 @@ | |
from .bcq import BCQ | ||
from .edac import EDAC | ||
from .ebm import EBM, AutoregressiveEBM | ||
from .havac import HAVAC |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import pytest | ||
import torch | ||
import random | ||
from ding.torch_utils import is_differentiable | ||
from ding.model.template import HAVAC | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestHAVAC: | ||
|
||
def test_havac_rnn_actor(self): | ||
# discrete+rnn | ||
bs, T = 3, 8 | ||
obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 | ||
agent_num = 5 | ||
data = { | ||
'obs': { | ||
'agent_state': torch.randn(T, bs, obs_dim), | ||
'global_state': torch.randn(T, bs, global_obs_dim), | ||
'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) | ||
}, | ||
'actor_prev_state': [None for _ in range(bs)], | ||
} | ||
model = HAVAC( | ||
agent_obs_shape=obs_dim, | ||
global_obs_shape=global_obs_dim, | ||
action_shape=action_dim, | ||
agent_num=agent_num, | ||
use_lstm=True, | ||
) | ||
agent_idx = random.randint(0, agent_num - 1) | ||
output = model(agent_idx, data, mode='compute_actor') | ||
assert set(output.keys()) == set(['logit', 'actor_next_state', 'actor_hidden_state']) | ||
assert output['logit'].shape == (T, bs, action_dim) | ||
assert len(output['actor_next_state']) == bs | ||
print(output['actor_next_state'][0]['h'].shape) | ||
loss = output['logit'].sum() | ||
is_differentiable(loss, model.agent_models[agent_idx].actor) | ||
|
||
def test_havac_rnn_critic(self): | ||
# discrete+rnn | ||
bs, T = 3, 8 | ||
obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 | ||
agent_num = 5 | ||
data = { | ||
'obs': { | ||
'agent_state': torch.randn(T, bs, obs_dim), | ||
'global_state': torch.randn(T, bs, global_obs_dim), | ||
'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) | ||
}, | ||
'critic_prev_state': [None for _ in range(bs)], | ||
} | ||
model = HAVAC( | ||
agent_obs_shape=obs_dim, | ||
global_obs_shape=global_obs_dim, | ||
action_shape=action_dim, | ||
agent_num=agent_num, | ||
use_lstm=True, | ||
) | ||
agent_idx = random.randint(0, agent_num - 1) | ||
output = model(agent_idx, data, mode='compute_critic') | ||
assert set(output.keys()) == set(['value', 'critic_next_state', 'critic_hidden_state']) | ||
assert output['value'].shape == (T, bs) | ||
assert len(output['critic_next_state']) == bs | ||
print(output['critic_next_state'][0]['h'].shape) | ||
loss = output['value'].sum() | ||
is_differentiable(loss, model.agent_models[agent_idx].critic) | ||
|
||
def test_havac_rnn_actor_critic(self): | ||
# discrete+rnn | ||
bs, T = 3, 8 | ||
obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 | ||
agent_num = 5 | ||
data = { | ||
'obs': { | ||
'agent_state': torch.randn(T, bs, obs_dim), | ||
'global_state': torch.randn(T, bs, global_obs_dim), | ||
'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) | ||
}, | ||
'actor_prev_state': [None for _ in range(bs)], | ||
'critic_prev_state': [None for _ in range(bs)], | ||
} | ||
model = HAVAC( | ||
agent_obs_shape=obs_dim, | ||
global_obs_shape=global_obs_dim, | ||
action_shape=action_dim, | ||
agent_num=agent_num, | ||
use_lstm=True, | ||
) | ||
agent_idx = random.randint(0, agent_num - 1) | ||
output = model(agent_idx, data, mode='compute_actor_critic') | ||
assert set(output.keys()) == set( | ||
['logit', 'actor_next_state', 'actor_hidden_state', 'value', 'critic_next_state', 'critic_hidden_state'] | ||
) | ||
assert output['logit'].shape == (T, bs, action_dim) | ||
assert output['value'].shape == (T, bs) | ||
loss = output['logit'].sum() + output['value'].sum() | ||
is_differentiable(loss, model.agent_models[agent_idx]) | ||
|
||
|
||
# test_havac_rnn_actor() | ||
# test_havac_rnn_critic() | ||
# test_havac_rnn_actor_critic() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.