From 4738444357b3836a78e56946b080251663e2c906 Mon Sep 17 00:00:00 2001 From: Ren Jiyuan <47732381+nighood@users.noreply.github.com> Date: Thu, 11 Jan 2024 12:11:03 +0800 Subject: [PATCH] feature(rjy): add HAPPO algorithm (#717) * model(rjy): add vac model for HAPPO * test(rjy): polish havac and add test * polish(rjy): fix conflict * polish(rjy): add hidden_state for ac * feature(rjy): change the havac to multiagent model * feature(rjy): add happo forward_learn * feature(rjy): modify the happo_data * test(rjy): add happo data test * feature(rjy): add HAPPO policy * feature(rjy): try to fit mu-mujoco * polish(rjy): Change code to adapt to mujoco * fix(rjy): fix the distribution in ppo update * fix(rjy): fix the happo+mujoco * config(rjy): add walker+happo config * polish(rjy): separate actors and critics * polish(rjy): polish according to comments * polish(rjy): fix the pipeline * polish(rjy): fix the style * polish(rjy): polish according to comments * polish(rjy): fix style * polish(rjy): fix style * polish(rjy): fix style * polish(rjy): seperate the happo model * fix(rjy): fix happo model style * polish(rjy): polish happo policy comments * polish(rjy): polish happo comments --- ding/model/common/head.py | 13 +- ding/model/template/__init__.py | 1 + ding/model/template/havac.py | 500 ++++++++++++ ding/model/template/tests/test_havac.py | 103 +++ ding/policy/__init__.py | 1 + ding/policy/command_mode_policy_instance.py | 6 + ding/policy/happo.py | 734 ++++++++++++++++++ ding/rl_utils/__init__.py | 10 +- ding/rl_utils/happo.py | 347 +++++++++ ding/rl_utils/tests/test_happo.py | 71 ++ .../config/halfcheetah_happo_config.py | 83 ++ .../config/halfcheetah_mappo_config.py | 80 ++ .../config/walker2d_happo_config.py | 91 +++ .../config/ptz_simple_spread_happo_config.py | 88 +++ 14 files changed, 2122 insertions(+), 6 deletions(-) create mode 100644 ding/model/template/havac.py create mode 100644 ding/model/template/tests/test_havac.py create mode 100644 ding/policy/happo.py create mode 100644 ding/rl_utils/happo.py create mode 100644 ding/rl_utils/tests/test_happo.py create mode 100644 dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py create mode 100644 dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py create mode 100644 dizoo/multiagent_mujoco/config/walker2d_happo_config.py create mode 100644 dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py diff --git a/ding/model/common/head.py b/ding/model/common/head.py index 99e94a85b1..1131e8a2e8 100755 --- a/ding/model/common/head.py +++ b/ding/model/common/head.py @@ -1102,8 +1102,9 @@ class ReparameterizationHead(nn.Module): Interfaces: ``__init__``, ``forward``. """ - - default_sigma_type = ['fixed', 'independent', 'conditioned'] + # The "happo" type here is to align with the sigma initialization method of the network in the original HAPPO \ + # paper. The code here needs to be optimized later. + default_sigma_type = ['fixed', 'independent', 'conditioned', 'happo'] default_bound_type = ['tanh', None] def __init__( @@ -1155,6 +1156,11 @@ def __init__( self.log_sigma_param = nn.Parameter(torch.zeros(1, output_size)) elif self.sigma_type == 'conditioned': self.log_sigma_layer = nn.Linear(hidden_size, output_size) + elif self.sigma_type == 'happo': + self.sigma_x_coef = 1. + self.sigma_y_coef = 0.5 + # This parameter (x_coef, y_coef) refers to the HAPPO paper http://arxiv.org/abs/2109.11251. 
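# Illustrative note (derived from the defaults defined in this branch):
# with sigma_x_coef = 1.0, sigma_y_coef = 0.5 and log_sigma_param initialised to
# torch.ones(1, output_size) * sigma_x_coef, the forward pass below computes
#     sigma = torch.sigmoid(log_sigma_param / sigma_x_coef) * sigma_y_coef
# so every action dimension starts at sigma ~= sigmoid(1.0) * 0.5 ~= 0.37 and stays bounded in (0, 0.5).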
+ self.log_sigma_param = nn.Parameter(torch.ones(1, output_size) * self.sigma_x_coef) def forward(self, x: torch.Tensor) -> Dict: """ @@ -1190,6 +1196,9 @@ def forward(self, x: torch.Tensor) -> Dict: elif self.sigma_type == 'conditioned': log_sigma = self.log_sigma_layer(x) sigma = torch.exp(torch.clamp(log_sigma, -20, 2)) + elif self.sigma_type == 'happo': + log_sigma = self.log_sigma_param + torch.zeros_like(mu) + sigma = torch.sigmoid(log_sigma / self.sigma_x_coef) * self.sigma_y_coef return {'mu': mu, 'sigma': sigma} diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 4a63c3dcc6..c9dc17791c 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -27,3 +27,4 @@ from .bcq import BCQ from .edac import EDAC from .ebm import EBM, AutoregressiveEBM +from .havac import HAVAC diff --git a/ding/model/template/havac.py b/ding/model/template/havac.py new file mode 100644 index 0000000000..77489ed517 --- /dev/null +++ b/ding/model/template/havac.py @@ -0,0 +1,500 @@ +from typing import Union, Dict, Optional +import torch +import torch.nn as nn + +from ding.torch_utils import get_lstm +from ding.utils import SequenceType, squeeze, MODEL_REGISTRY +from ding.model.template.q_learning import parallel_wrapper +from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, \ + FCEncoder, ConvEncoder + + +class RNNLayer(nn.Module): + + def __init__(self, lstm_type, input_size, hidden_size, res_link: bool = False): + super(RNNLayer, self).__init__() + self.rnn = get_lstm(lstm_type, input_size=input_size, hidden_size=hidden_size) + self.res_link = res_link + + def forward(self, x, prev_state, inference: bool = False): + """ + Forward pass of the RNN layer. + If inference is True, sequence length of input is set to 1. + If res_link is True, a residual link is added to the output. + """ + # x: obs_embedding + if self.res_link: + a = x + if inference: + x = x.unsqueeze(0) # for rnn input, put the seq_len of x as 1 instead of none. + # prev_state: DataType: List[Tuple[torch.Tensor]]; Initially, it is a list of None + x, next_state = self.rnn(x, prev_state) + x = x.squeeze(0) # to delete the seq_len dim to match head network input + if self.res_link: + x = x + a + return {'output': x, 'next_state': next_state} + else: + # lstm_embedding stores all hidden_state + lstm_embedding = [] + hidden_state_list = [] + for t in range(x.shape[0]): # T timesteps + # use x[t:t+1] but not x[t] can keep original dimension + output, prev_state = self.rnn(x[t:t + 1], prev_state) # output: (1,B, head_hidden_size) + lstm_embedding.append(output) + hidden_state = [p['h'] for p in prev_state] + # only keep ht, {list: x.shape[0]{Tensor:(1, batch_size, head_hidden_size)}} + hidden_state_list.append(torch.cat(hidden_state, dim=1)) + x = torch.cat(lstm_embedding, 0) # (T, B, head_hidden_size) + if self.res_link: + x = x + a + all_hidden_state = torch.cat(hidden_state_list, dim=0) + return {'output': x, 'next_state': prev_state, 'hidden_state': all_hidden_state} + + +@MODEL_REGISTRY.register('havac') +class HAVAC(nn.Module): + """ + Overview: + The HAVAC model of each agent for HAPPO. 
+ Interfaces: + ``__init__``, ``forward`` + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__( + self, + agent_obs_shape: Union[int, SequenceType], + global_obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + agent_num: int, + use_lstm: bool = False, + lstm_type: str = 'gru', + encoder_hidden_size_list: SequenceType = [128, 128, 64], + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 2, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + action_space: str = 'discrete', + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + sigma_type: Optional[str] = 'independent', + bound_type: Optional[str] = None, + res_link: bool = False, + ) -> None: + r""" + Overview: + Init the VAC Model for HAPPO according to arguments. + Arguments: + - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent. + - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for global agent + - action_shape (:obj:`Union[int, SequenceType]`): Action's space. + - agent_num (:obj:`int`): Number of agents. + - lstm_type (:obj:`str`): use lstm or gru, default to gru + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder`` + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. + - actor_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for actor's nn. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. + - critic_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for critic's nn. + - activation (:obj:`Optional[nn.Module]`): + The type of activation function to use in ``MLP`` the after ``layer_fn``, + if ``None`` then default set to ``nn.ReLU()`` + - norm_type (:obj:`Optional[str]`): + The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details` + - res_link (:obj:`bool`): use the residual link or not, default to False + """ + super(HAVAC, self).__init__() + self.agent_num = agent_num + self.agent_models = nn.ModuleList( + [ + HAVACAgent( + agent_obs_shape=agent_obs_shape, + global_obs_shape=global_obs_shape, + action_shape=action_shape, + use_lstm=use_lstm, + action_space=action_space, + ) for _ in range(agent_num) + ] + ) + + def forward(self, agent_idx, input_data, mode): + selected_agent_model = self.agent_models[agent_idx] + output = selected_agent_model(input_data, mode) + return output + + +class HAVACAgent(nn.Module): + """ + Overview: + The HAVAC model of each agent for HAPPO. 
+ Interfaces: + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic`` + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__( + self, + agent_obs_shape: Union[int, SequenceType], + global_obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + use_lstm: bool = False, + lstm_type: str = 'gru', + encoder_hidden_size_list: SequenceType = [128, 128, 64], + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 2, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + action_space: str = 'discrete', + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + sigma_type: Optional[str] = 'happo', + bound_type: Optional[str] = None, + res_link: bool = False, + ) -> None: + r""" + Overview: + Init the VAC Model for HAPPO according to arguments. + Arguments: + - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent. + - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for global agent + - action_shape (:obj:`Union[int, SequenceType]`): Action's space. + - lstm_type (:obj:`str`): use lstm or gru, default to gru + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder`` + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. + - actor_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for actor's nn. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. + - critic_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for critic's nn. + - activation (:obj:`Optional[nn.Module]`): + The type of activation function to use in ``MLP`` the after ``layer_fn``, + if ``None`` then default set to ``nn.ReLU()`` + - norm_type (:obj:`Optional[str]`): + The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details` + - res_link (:obj:`bool`): use the residual link or not, default to False + """ + super(HAVACAgent, self).__init__() + agent_obs_shape: int = squeeze(agent_obs_shape) + global_obs_shape: int = squeeze(global_obs_shape) + action_shape: int = squeeze(action_shape) + self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape + self.action_space = action_space + # Encoder Type + if isinstance(agent_obs_shape, int) or len(agent_obs_shape) == 1: + actor_encoder_cls = FCEncoder + elif len(agent_obs_shape) == 3: + actor_encoder_cls = ConvEncoder + else: + raise RuntimeError( + "not support obs_shape for pre-defined encoder: {}, please customize your own VAC". + format(agent_obs_shape) + ) + if isinstance(global_obs_shape, int) or len(global_obs_shape) == 1: + critic_encoder_cls = FCEncoder + elif len(global_obs_shape) == 3: + critic_encoder_cls = ConvEncoder + else: + raise RuntimeError( + "not support obs_shape for pre-defined encoder: {}, please customize your own VAC". + format(global_obs_shape) + ) + + # We directly connect the Head after a Liner layer instead of using the 3-layer FCEncoder. + # In SMAC task it can obviously improve the performance. + # Users can change the model according to their own needs. 
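# Sizing note (follows from the defaults above): when use_lstm=False the encoder output feeds the
# heads directly, so encoder_hidden_size_list[-1] (64 by default) must equal
# actor_head_hidden_size / critic_head_hidden_size (also 64 by default); when use_lstm=True the
# RNNLayer below maps encoder_hidden_size_list[-1] to the corresponding head hidden size instead.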
+ self.actor_encoder = actor_encoder_cls( + obs_shape=agent_obs_shape, + hidden_size_list=encoder_hidden_size_list, + activation=activation, + norm_type=norm_type + ) + self.critic_encoder = critic_encoder_cls( + obs_shape=global_obs_shape, + hidden_size_list=encoder_hidden_size_list, + activation=activation, + norm_type=norm_type + ) + # RNN part + self.use_lstm = use_lstm + if self.use_lstm: + self.actor_rnn = RNNLayer( + lstm_type, + input_size=encoder_hidden_size_list[-1], + hidden_size=actor_head_hidden_size, + res_link=res_link + ) + self.critic_rnn = RNNLayer( + lstm_type, + input_size=encoder_hidden_size_list[-1], + hidden_size=critic_head_hidden_size, + res_link=res_link + ) + # Head Type + self.critic_head = RegressionHead( + critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type + ) + assert self.action_space in ['discrete', 'continuous'], self.action_space + if self.action_space == 'discrete': + self.actor_head = DiscreteHead( + actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type + ) + elif self.action_space == 'continuous': + self.actor_head = ReparameterizationHead( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + sigma_type=sigma_type, + activation=activation, + norm_type=norm_type, + bound_type=bound_type + ) + # must use list, not nn.ModuleList + self.actor = [self.actor_encoder, self.actor_rnn, self.actor_head] if self.use_lstm \ + else [self.actor_encoder, self.actor_head] + self.critic = [self.critic_encoder, self.critic_rnn, self.critic_head] if self.use_lstm \ + else [self.critic_encoder, self.critic_head] + # for convenience of call some apis(such as: self.critic.parameters()), but may cause + # misunderstanding when print(self) + self.actor = nn.ModuleList(self.actor) + self.critic = nn.ModuleList(self.critic) + + def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: + r""" + Overview: + Use encoded embedding tensor to predict output. + Parameter updates with VAC's MLPs forward setup. + Arguments: + Forward with ``'compute_actor'`` or ``'compute_critic'``: + - inputs (:obj:`torch.Tensor`): + The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. + Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``. + Returns: + - outputs (:obj:`Dict`): + Run with encoder and head. + + Forward with ``'compute_actor'``, Necessary Keys: + - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``. + + Forward with ``'compute_critic'``, Necessary Keys: + - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N corresponding ``hidden_size`` + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. 
+ + Actor Examples: + >>> model = VAC(64,128) + >>> inputs = torch.randn(4, 64) + >>> actor_outputs = model(inputs,'compute_actor') + >>> assert actor_outputs['logit'].shape == torch.Size([4, 128]) + + Critic Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> critic_outputs = model(inputs,'compute_critic') + >>> critic_outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + + Actor-Critic Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> outputs = model(inputs,'compute_actor_critic') + >>> outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + >>> assert outputs['logit'].shape == torch.Size([4, 64]) + + """ + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(inputs) + + def compute_actor(self, inputs: Dict, inference: bool = False) -> Dict: + r""" + Overview: + Execute parameter updates with ``'compute_actor'`` mode + Use encoded embedding tensor to predict output. + Arguments: + - inputs (:obj:`torch.Tensor`): + input data dict with keys ['obs'(with keys ['agent_state', 'global_state', 'action_mask']), + 'actor_prev_state'] + Returns: + - outputs (:obj:`Dict`): + Run with encoder RNN(optional) and head. + + ReturnsKeys: + - logit (:obj:`torch.Tensor`): Logit encoding tensor. + - actor_next_state: + - hidden_state + Shapes: + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + - actor_next_state: (B,) + - hidden_state: + + Examples: + >>> model = HAVAC( + agent_obs_shape=obs_dim, + global_obs_shape=global_obs_dim, + action_shape=action_dim, + use_lstm = True, + ) + >>> inputs = { + 'obs': { + 'agent_state': torch.randn(T, bs, obs_dim), + 'global_state': torch.randn(T, bs, global_obs_dim), + 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) + }, + 'actor_prev_state': [None for _ in range(bs)], + } + >>> actor_outputs = model(inputs,'compute_actor') + >>> assert actor_outputs['logit'].shape == (T, bs, action_dim) + """ + x = inputs['obs']['agent_state'] + output = {} + if self.use_lstm: + rnn_actor_prev_state = inputs['actor_prev_state'] + if inference: + x = self.actor_encoder(x) + rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference) + x = rnn_output['output'] + x = self.actor_head(x) + output['next_state'] = rnn_output['next_state'] + # output: 'logit'/'next_state' + else: + assert len(x.shape) in [3, 5], x.shape + x = parallel_wrapper(self.actor_encoder)(x) # (T, B, N) + rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference) + x = rnn_output['output'] + x = parallel_wrapper(self.actor_head)(x) + output['actor_next_state'] = rnn_output['next_state'] + output['actor_hidden_state'] = rnn_output['hidden_state'] + # output: 'logit'/'actor_next_state'/'hidden_state' + else: + x = self.actor_encoder(x) + x = self.actor_head(x) + # output: 'logit' + + if self.action_space == 'discrete': + action_mask = inputs['obs']['action_mask'] + logit = x['logit'] + logit[action_mask == 0.0] = -99999999 + elif self.action_space == 'continuous': + logit = x + output['logit'] = logit + return output + + def compute_critic(self, inputs: Dict, inference: bool = False) -> Dict: + r""" + Overview: + Execute parameter updates with ``'compute_critic'`` mode + Use encoded embedding tensor to predict output. 
+ Arguments: + - inputs (:obj:`Dict`): + input data dict with keys ['obs'(with keys ['agent_state', 'global_state', 'action_mask']), + 'critic_prev_state'(when you are using rnn)] + Returns: + - outputs (:obj:`Dict`): + Run with encoder [rnn] and head. + + Necessary Keys: + - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + - logits + Shapes: + - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + - logits + + Examples: + >>> model = HAVAC( + agent_obs_shape=obs_dim, + global_obs_shape=global_obs_dim, + action_shape=action_dim, + use_lstm = True, + ) + >>> inputs = { + 'obs': { + 'agent_state': torch.randn(T, bs, obs_dim), + 'global_state': torch.randn(T, bs, global_obs_dim), + 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) + }, + 'critic_prev_state': [None for _ in range(bs)], + } + >>> critic_outputs = model(inputs,'compute_critic') + >>> assert critic_outputs['value'].shape == (T, bs)) + """ + global_obs = inputs['obs']['global_state'] + output = {} + if self.use_lstm: + rnn_critic_prev_state = inputs['critic_prev_state'] + if inference: + x = self.critic_encoder(global_obs) + rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference) + x = rnn_output['output'] + x = self.critic_head(x) + output['next_state'] = rnn_output['next_state'] + # output: 'value'/'next_state' + else: + assert len(global_obs.shape) in [3, 5], global_obs.shape + x = parallel_wrapper(self.critic_encoder)(global_obs) # (T, B, N) + rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference) + x = rnn_output['output'] + x = parallel_wrapper(self.critic_head)(x) + output['critic_next_state'] = rnn_output['next_state'] + output['critic_hidden_state'] = rnn_output['hidden_state'] + # output: 'value'/'critic_next_state'/'hidden_state' + else: + x = self.critic_encoder(global_obs) + x = self.critic_head(x) + # output: 'value' + output['value'] = x['pred'] + return output + + def compute_actor_critic(self, inputs: Dict, inference: bool = False) -> Dict: + r""" + Overview: + Execute parameter updates with ``'compute_actor_critic'`` mode + Use encoded embedding tensor to predict output. + Arguments: + - inputs (:dict): input data dict with keys + ['obs'(with keys ['agent_state', 'global_state', 'action_mask']), + 'actor_prev_state', 'critic_prev_state'(when you are using rnn)] + + Returns: + - outputs (:obj:`Dict`): + Run with encoder and head. + + ReturnsKeys: + - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``. + - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + + Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> outputs = model(inputs,'compute_actor_critic') + >>> outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + >>> assert outputs['logit'].shape == torch.Size([4, 64]) + + + .. note:: + ``compute_actor_critic`` interface aims to save computation when shares encoder. + Returning the combination dictionry. 
+ + """ + actor_output = self.compute_actor(inputs, inference) + critic_output = self.compute_critic(inputs, inference) + if self.use_lstm: + return { + 'logit': actor_output['logit'], + 'value': critic_output['value'], + 'actor_next_state': actor_output['actor_next_state'], + 'actor_hidden_state': actor_output['actor_hidden_state'], + 'critic_next_state': critic_output['critic_next_state'], + 'critic_hidden_state': critic_output['critic_hidden_state'], + } + else: + return { + 'logit': actor_output['logit'], + 'value': critic_output['value'], + } diff --git a/ding/model/template/tests/test_havac.py b/ding/model/template/tests/test_havac.py new file mode 100644 index 0000000000..42982ec5ae --- /dev/null +++ b/ding/model/template/tests/test_havac.py @@ -0,0 +1,103 @@ +import pytest +import torch +import random +from ding.torch_utils import is_differentiable +from ding.model.template import HAVAC + + +@pytest.mark.unittest +class TestHAVAC: + + def test_havac_rnn_actor(self): + # discrete+rnn + bs, T = 3, 8 + obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 + agent_num = 5 + data = { + 'obs': { + 'agent_state': torch.randn(T, bs, obs_dim), + 'global_state': torch.randn(T, bs, global_obs_dim), + 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) + }, + 'actor_prev_state': [None for _ in range(bs)], + } + model = HAVAC( + agent_obs_shape=obs_dim, + global_obs_shape=global_obs_dim, + action_shape=action_dim, + agent_num=agent_num, + use_lstm=True, + ) + agent_idx = random.randint(0, agent_num - 1) + output = model(agent_idx, data, mode='compute_actor') + assert set(output.keys()) == set(['logit', 'actor_next_state', 'actor_hidden_state']) + assert output['logit'].shape == (T, bs, action_dim) + assert len(output['actor_next_state']) == bs + print(output['actor_next_state'][0]['h'].shape) + loss = output['logit'].sum() + is_differentiable(loss, model.agent_models[agent_idx].actor) + + def test_havac_rnn_critic(self): + # discrete+rnn + bs, T = 3, 8 + obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 + agent_num = 5 + data = { + 'obs': { + 'agent_state': torch.randn(T, bs, obs_dim), + 'global_state': torch.randn(T, bs, global_obs_dim), + 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) + }, + 'critic_prev_state': [None for _ in range(bs)], + } + model = HAVAC( + agent_obs_shape=obs_dim, + global_obs_shape=global_obs_dim, + action_shape=action_dim, + agent_num=agent_num, + use_lstm=True, + ) + agent_idx = random.randint(0, agent_num - 1) + output = model(agent_idx, data, mode='compute_critic') + assert set(output.keys()) == set(['value', 'critic_next_state', 'critic_hidden_state']) + assert output['value'].shape == (T, bs) + assert len(output['critic_next_state']) == bs + print(output['critic_next_state'][0]['h'].shape) + loss = output['value'].sum() + is_differentiable(loss, model.agent_models[agent_idx].critic) + + def test_havac_rnn_actor_critic(self): + # discrete+rnn + bs, T = 3, 8 + obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 + agent_num = 5 + data = { + 'obs': { + 'agent_state': torch.randn(T, bs, obs_dim), + 'global_state': torch.randn(T, bs, global_obs_dim), + 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim)) + }, + 'actor_prev_state': [None for _ in range(bs)], + 'critic_prev_state': [None for _ in range(bs)], + } + model = HAVAC( + agent_obs_shape=obs_dim, + global_obs_shape=global_obs_dim, + action_shape=action_dim, + agent_num=agent_num, + use_lstm=True, + ) + agent_idx = random.randint(0, agent_num - 1) + output = model(agent_idx, data, 
mode='compute_actor_critic') + assert set(output.keys()) == set( + ['logit', 'actor_next_state', 'actor_hidden_state', 'value', 'critic_next_state', 'critic_hidden_state'] + ) + assert output['logit'].shape == (T, bs, action_dim) + assert output['value'].shape == (T, bs) + loss = output['logit'].sum() + output['value'].sum() + is_differentiable(loss, model.agent_models[agent_idx]) + + +# test_havac_rnn_actor() +# test_havac_rnn_critic() +# test_havac_rnn_actor_critic() diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 25e8b67c4d..c85883a0af 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -55,3 +55,4 @@ # new-type policy from .ppof import PPOFPolicy from .prompt_pg import PromptPGPolicy +from .happo import HAPPOPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 2d5e3271dd..2e817ead4b 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -51,6 +51,7 @@ from .edac import EDACPolicy from .prompt_pg import PromptPGPolicy from .plan_diffuser import PDPolicy +from .happo import HAPPOPolicy class EpsCommandModePolicy(CommandModePolicy): @@ -186,6 +187,11 @@ class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy): pass +@POLICY_REGISTRY.register('happo_command') +class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy): + pass + + @POLICY_REGISTRY.register('ppo_stdim_command') class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy): pass diff --git a/ding/policy/happo.py b/ding/policy/happo.py new file mode 100644 index 0000000000..4cbd38324b --- /dev/null +++ b/ding/policy/happo.py @@ -0,0 +1,734 @@ +from typing import List, Dict, Any, Tuple, Union +from collections import namedtuple +import torch +import copy +import numpy as np +from torch.distributions import Independent, Normal + +from ding.torch_utils import Adam, to_device, to_dtype, unsqueeze, ContrastiveLoss +from ding.rl_utils import happo_data, happo_error, happo_policy_error, happo_policy_data, \ + v_nstep_td_data, v_nstep_td_error, get_train_sample, gae, gae_data, happo_error_continuous, \ + get_gae +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd +from ding.utils.data import default_collate, default_decollate +from .base_policy import Policy +from .common_utils import default_preprocess_learn + + +@POLICY_REGISTRY.register('happo') +class HAPPOPolicy(Policy): + """ + Overview: + Policy class of on policy version HAPPO algorithm. Paper link: https://arxiv.org/abs/2109.11251. + """ + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='happo', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used) + on_policy=True, + # (bool) Whether to use priority(priority sample, IS weight, update priority) + priority=False, + # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority. + # If True, priority must be True. 
+ priority_IS_weight=False, + # (bool) Whether to recompurete advantages in each iteration of on-policy PPO + recompute_adv=True, + # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid'] + action_space='discrete', + # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value + nstep_return=False, + # (bool) Whether to enable multi-agent training, i.e.: MAPPO + multi_agent=False, + # (bool) Whether to need policy data in process transition + transition_with_policy_data=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.0, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=True, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=0.5, + ignore_done=False, + ), + collect=dict( + # (int) Only one of [n_sample, n_episode] shoule be set + # n_sample=64, + # (int) Cut trajectories into pieces with length "unroll_len". + unroll_len=1, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) Reward's future discount factor, aka. gamma. + discount_factor=0.99, + # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc) + gae_lambda=0.95, + ), + eval=dict(), + ) + + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For HAPPO, it mainly \ + contains optimizer, algorithm-specific arguments such as loss weight, clip_ratio and recompute_adv. This \ + method also executes some special network initializations and prepares running mean/std monitor for value. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 
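        .. note::
            This method also builds the critic optimizer from ``self._cfg.learn.critic_learning_rate``, \
            which is not defined in the default ``config`` above, so user configs are expected to set it \
            explicitly, e.g. ``learn=dict(learning_rate=5e-4, critic_learning_rate=5e-4)`` (the values \
            here are only illustrative).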
+ """ + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO" + + assert self._cfg.action_space in ["continuous", "discrete"] + self._action_space = self._cfg.action_space + if self._cfg.learn.ppo_param_init: + for n, m in self._model.named_modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + if self._action_space in ['continuous']: + # init log sigma + for agent_id in range(self._cfg.agent_num): + # if hasattr(self._model.agent_models[agent_id].actor_head, 'log_sigma_param'): + # torch.nn.init.constant_(self._model.agent_models[agent_id].actor_head.log_sigma_param, 1) + # The above initialization step has been changed to reparameterizationHead. + for m in list(self._model.agent_models[agent_id].critic.modules()) + \ + list(self._model.agent_models[agent_id].actor.modules()): + if isinstance(m, torch.nn.Linear): + # orthogonal initialization + torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + # do last policy layer scaling, this will make initial actions have (close to) + # 0 mean and std, and will help boost performances, + # see https://arxiv.org/abs/2006.05990, Fig.24 for details + for m in self._model.agent_models[agent_id].actor.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.zeros_(m.bias) + m.weight.data.copy_(0.01 * m.weight.data) + + # Add the actor/critic parameters of each HAVACAgent in HAVAC to the parameter list of actor/critic_optimizer + actor_params = [] + critic_params = [] + for agent_idx in range(self._model.agent_num): + actor_params.append({'params': self._model.agent_models[agent_idx].actor.parameters()}) + critic_params.append({'params': self._model.agent_models[agent_idx].critic.parameters()}) + + self._actor_optimizer = Adam( + actor_params, + lr=self._cfg.learn.learning_rate, + grad_clip_type=self._cfg.learn.grad_clip_type, + clip_value=self._cfg.learn.grad_clip_value, + # eps = 1e-5, + ) + + self._critic_optimizer = Adam( + critic_params, + lr=self._cfg.learn.critic_learning_rate, + grad_clip_type=self._cfg.learn.grad_clip_type, + clip_value=self._cfg.learn.grad_clip_value, + # eps = 1e-5, + ) + + self._learn_model = model_wrap(self._model, wrapper_name='base') + # self._learn_model = model_wrap( + # self._model, + # wrapper_name='hidden_state', + # state_num=self._cfg.learn.batch_size, + # init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)] + # ) + + # Algorithm config + self._value_weight = self._cfg.learn.value_weight + self._entropy_weight = self._cfg.learn.entropy_weight + self._clip_ratio = self._cfg.learn.clip_ratio + self._adv_norm = self._cfg.learn.adv_norm + self._value_norm = self._cfg.learn.value_norm + if self._value_norm: + self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device) + self._gamma = self._cfg.collect.discount_factor + self._gae_lambda = self._cfg.collect.gae_lambda + self._recompute_adv = self._cfg.recompute_adv + # Main model + self._learn_model.reset() + + def prepocess_data_agent(self, data: Dict[str, Any]): + """ + Overview: + Preprocess data for agent dim. This function is used in learn mode. \ + It will be called recursively to process nested dict data. \ + It will transpose the data with shape (B, agent_num, ...) to (agent_num, B, ...). \ + Arguments: + - data (:obj:`dict`): Dict type data, where each element is the data of an agent of dict type. 
+ Returns: + - ret (:obj:`dict`): Dict type data, where each element is the data of an agent of dict type. + """ + ret = {} + for key, value in data.items(): + if isinstance(value, dict): + ret[key] = self.prepocess_data_agent(value) + elif isinstance(value, torch.Tensor) and len(value.shape) > 1: + ret[key] = value.transpose(0, 1) + else: + ret[key] = value + return ret + + def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): List type data, where each element is the data of an agent of dict type. + Returns: + - info_dict (:obj:`Dict[str, Any]`): + Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \ + adv_abs_max, approx_kl, clipfrac + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, clipfrac, approx_kl. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \ + collected training samples for on-policy algorithms like HAPPO. For each element in list, the key of \ + dict is the name of data items and the value is the corresponding data. Usually, the value is \ + torch.Tensor or np.ndarray or there dict/list combinations. In the ``_forward_learn`` method, data \ + often need to first be stacked in the batch dimension by some utility functions such as \ + ``default_preprocess_learn``. \ + For HAPPO, each element in list is a dict containing at least the following keys: ``obs``, \ + ``action``, ``reward``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys \ + such as ``weight``. + Returns: + - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicated training result, each \ + training iteration contains append a information dict into the final list. The list will be precessed \ + and recorded in text log and tensorboard. The value of the dict must be python scalar or a list of \ + scalars. For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. tip:: + The training procedure of HAPPO is three for loops. The outermost loop trains each agent separately. \ + The middle loop trains all the collected training samples with ``epoch_per_collect`` epochs. The inner \ + loop splits all the data into different mini-batch with the length of ``batch_size``. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for HAPPOPolicy: ``ding.policy.tests.test_happo``. + """ + data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + all_data_len = data['obs']['agent_state'].shape[0] + # fator is the ratio of the old and new strategies of the first m-1 agents, initialized to 1. + # Each transition has its own factor. 
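# Sketch of how this factor is used (mirrors the update at the end of each agent's loop below):
#     factor = factor * torch.prod(torch.exp(logp_new - logp_old), dim=-1).reshape(all_data_len, 1).detach()
# i.e. after agent m finishes its epochs, factor accumulates that agent's per-transition policy ratio,
# so agent m+1's clipped surrogate loss is weighted by the product of the ratios of all previously
# updated agents -- the sequential-update correction described in the HAPPO paper.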
ref: http://arxiv.org/abs/2109.11251 + factor = torch.ones(all_data_len, 1) # (B, 1) + if self._cuda: + data = to_device(data, self._device) + factor = to_device(factor, self._device) + # process agent dim + data = self.prepocess_data_agent(data) + # ==================== + # PPO forward + # ==================== + return_infos = [] + self._learn_model.train() + + for agent_id in range(self._cfg.agent_num): + agent_data = {} + for key, value in data.items(): + if value is not None: + if type(value) is dict: + agent_data[key] = {k: v[agent_id] for k, v in value.items()} # not feasible for rnn + elif len(value.shape) > 1: + agent_data[key] = data[key][agent_id] + else: + agent_data[key] = data[key] + else: + agent_data[key] = data[key] + + # update factor + agent_data['factor'] = factor + # calculate old_logits of all data in buffer for later factor + inputs = { + 'obs': agent_data['obs'], + # 'actor_prev_state': agent_data['actor_prev_state'], + # 'critic_prev_state': agent_data['critic_prev_state'], + } + old_logits = self._learn_model.forward(agent_id, inputs, mode='compute_actor')['logit'] + + for epoch in range(self._cfg.learn.epoch_per_collect): + if self._recompute_adv: # calculate new value using the new updated value network + with torch.no_grad(): + inputs['obs'] = agent_data['obs'] + # value = self._learn_model.forward(agent_id, agent_data['obs'], mode='compute_critic')['value'] + value = self._learn_model.forward(agent_id, inputs, mode='compute_critic')['value'] + inputs['obs'] = agent_data['next_obs'] + next_value = self._learn_model.forward(agent_id, inputs, mode='compute_critic')['value'] + if self._value_norm: + value *= self._running_mean_std.std + next_value *= self._running_mean_std.std + + traj_flag = agent_data.get('traj_flag', None) # traj_flag indicates termination of trajectory + compute_adv_data = gae_data( + value, next_value, agent_data['reward'], agent_data['done'], traj_flag + ) + agent_data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda) + + unnormalized_returns = value + agent_data['adv'] + + if self._value_norm: + agent_data['value'] = value / self._running_mean_std.std + agent_data['return'] = unnormalized_returns / self._running_mean_std.std + self._running_mean_std.update(unnormalized_returns.cpu().numpy()) + else: + agent_data['value'] = value + agent_data['return'] = unnormalized_returns + + else: # don't recompute adv + if self._value_norm: + unnormalized_return = agent_data['adv'] + agent_data['value'] * self._running_mean_std.std + agent_data['return'] = unnormalized_return / self._running_mean_std.std + self._running_mean_std.update(unnormalized_return.cpu().numpy()) + else: + agent_data['return'] = agent_data['adv'] + agent_data['value'] + + for batch in split_data_generator(agent_data, self._cfg.learn.batch_size, shuffle=True): + inputs = { + 'obs': batch['obs'], + # 'actor_prev_state': batch['actor_prev_state'], + # 'critic_prev_state': batch['critic_prev_state'], + } + output = self._learn_model.forward(agent_id, inputs, mode='compute_actor_critic') + adv = batch['adv'] + if self._adv_norm: + # Normalize advantage in a train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Calculate happo error + if self._action_space == 'continuous': + happo_batch = happo_data( + output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, + batch['return'], batch['weight'], batch['factor'] + ) + happo_loss, happo_info = happo_error_continuous(happo_batch, self._clip_ratio) + elif self._action_space == 'discrete': 
+ happo_batch = happo_data( + output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, + batch['return'], batch['weight'], batch['factor'] + ) + happo_loss, happo_info = happo_error(happo_batch, self._clip_ratio) + wv, we = self._value_weight, self._entropy_weight + total_loss = happo_loss.policy_loss + wv * happo_loss.value_loss - we * happo_loss.entropy_loss + + # actor update + # critic update + self._actor_optimizer.zero_grad() + self._critic_optimizer.zero_grad() + total_loss.backward() + self._actor_optimizer.step() + self._critic_optimizer.step() + + return_info = { + 'agent{}_cur_lr'.format(agent_id): self._actor_optimizer.defaults['lr'], + 'agent{}_total_loss'.format(agent_id): total_loss.item(), + 'agent{}_policy_loss'.format(agent_id): happo_loss.policy_loss.item(), + 'agent{}_value_loss'.format(agent_id): happo_loss.value_loss.item(), + 'agent{}_entropy_loss'.format(agent_id): happo_loss.entropy_loss.item(), + 'agent{}_adv_max'.format(agent_id): adv.max().item(), + 'agent{}_adv_mean'.format(agent_id): adv.mean().item(), + 'agent{}_value_mean'.format(agent_id): output['value'].mean().item(), + 'agent{}_value_max'.format(agent_id): output['value'].max().item(), + 'agent{}_approx_kl'.format(agent_id): happo_info.approx_kl, + 'agent{}_clipfrac'.format(agent_id): happo_info.clipfrac, + } + if self._action_space == 'continuous': + return_info.update( + { + 'agent{}_act'.format(agent_id): batch['action'].float().mean().item(), + 'agent{}_mu_mean'.format(agent_id): output['logit']['mu'].mean().item(), + 'agent{}_sigma_mean'.format(agent_id): output['logit']['sigma'].mean().item(), + } + ) + return_infos.append(return_info) + # calculate the factor + inputs = { + 'obs': agent_data['obs'], + # 'actor_prev_state': agent_data['actor_prev_state'], + } + new_logits = self._learn_model.forward(agent_id, inputs, mode='compute_actor')['logit'] + if self._cfg.action_space == 'discrete': + dist_new = torch.distributions.categorical.Categorical(logits=new_logits) + dist_old = torch.distributions.categorical.Categorical(logits=old_logits) + elif self._cfg.action_space == 'continuous': + dist_new = Normal(new_logits['mu'], new_logits['sigma']) + dist_old = Normal(old_logits['mu'], old_logits['sigma']) + logp_new = dist_new.log_prob(agent_data['action']) + logp_old = dist_old.log_prob(agent_data['action']) + if len(logp_new.shape) > 1: + # for logp with shape(B, action_shape), we need to calculate the product of all action dimensions. + factor = factor * torch.prod( + torch.exp(logp_new - logp_old), dim=-1 + ).reshape(all_data_len, 1).detach() # attention the shape + else: + # for logp with shape(B, ), directly calculate factor + factor = factor * torch.exp(logp_new - logp_old).reshape(all_data_len, 1).detach() + return return_infos + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode optimizer and model. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn mode. It contains the \ + state_dict of current policy network and optimizer. + """ + return { + 'model': self._learn_model.state_dict(), + 'actor_optimizer': self._actor_optimizer.state_dict(), + 'critic_optimizer': self._critic_optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict of learn mode optimizer and model. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn mode. 
It contains the state_dict \ + of current policy network and optimizer. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._actor_optimizer.load_state_dict(state_dict['actor_optimizer']) + self._critic_optimizer.load_state_dict(state_dict['critic_optimizer']) + + def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For HAPPO, it contains \ + the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \ + discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and gae_lambda in PPO. \ + This design is for the convenience of parallel execution of different policy modes. + """ + self._unroll_len = self._cfg.collect.unroll_len + assert self._cfg.action_space in ["continuous", "discrete"] + self._action_space = self._cfg.action_space + if self._action_space == 'continuous': + self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample') + elif self._action_space == 'discrete': + self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample') + self._collect_model.reset() + self._gamma = self._cfg.collect.discount_factor + self._gae_lambda = self._cfg.collect.gae_lambda + self._recompute_adv = self._cfg.recompute_adv + + def _forward_collect(self, data: Dict[int, Any]) -> dict: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \ + method. The key of the dict is the same as the input data, i.e. environment id. + + .. tip:: + If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \ + related data as extra keyword arguments of this method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for HAPPOPolicy: ``ding.policy.tests.test_happo``. 
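        Examples (illustrative sketch; a discrete-action setting with M agents, N-dim agent obs, G-dim \
            global obs and A actions is assumed):
            >>> obs = {'agent_state': torch.randn(M, N), 'global_state': torch.randn(M, G), 'action_mask': torch.ones(M, A)}
            >>> output = policy.collect_mode.forward({0: obs})
            >>> assert {'action', 'logit', 'value'} <= set(output[0].keys())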
+ """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + data = {k: v.transpose(0, 1) for k, v in data.items()} # not feasible for rnn + self._collect_model.eval() + with torch.no_grad(): + outputs = [] + for agent_id in range(self._cfg.agent_num): + # output = self._collect_model.forward(agent_id, data, mode='compute_actor_critic') + single_agent_obs = {k: v[agent_id] for k, v in data.items()} + input = { + 'obs': single_agent_obs, + } + output = self._collect_model.forward(agent_id, input, mode='compute_actor_critic') + outputs.append(output) + # transfer data from (M, B, N)->(B, M, N) + result = {} + for key in outputs[0].keys(): + if isinstance(outputs[0][key], dict): + subkeys = outputs[0][key].keys() + stacked_subvalues = {} + for subkey in subkeys: + stacked_subvalues[subkey] = \ + torch.stack([output[key][subkey] for output in outputs], dim=0).transpose(0, 1) + result[key] = stacked_subvalues + else: + # If Value is tensor, stack it directly + if isinstance(outputs[0][key], torch.Tensor): + result[key] = torch.stack([output[key] for output in outputs], dim=0).transpose(0, 1) + else: + # If it is not tensor, assume that it is a non-stackable data type \ + # (such as int, float, etc.), and directly retain the original value + result[key] = [output[key] for output in outputs] + output = result + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + """ + Overview: + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For HAPPO, it contains obs, next_obs, action, reward, done, logit, value. + Arguments: + - obs (:obj:`torch.Tensor`): The env observation of current timestep. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For PPO, it contains the state value, action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + + .. note:: + ``next_obs`` is used to calculate nstep return when necessary, so we place in into transition by default. \ + You can delete this field to save memory occupancy if you do not need nstep return. + """ + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': model_output['action'], + 'logit': model_output['logit'], + 'value': model_output['value'], + 'reward': timestep.reward, + 'done': timestep.done, + } + return transition + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In HAPPO, a train sample is a processed transition with new computed \ + ``traj_flag`` and ``adv`` field. This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help learner amortize revelant time consumption. 
\ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training, such as GAE advantage. + """ + data = to_device(data, self._device) + for transition in data: + transition['traj_flag'] = copy.deepcopy(transition['done']) + data[-1]['traj_flag'] = True + + if self._cfg.learn.ignore_done: + data[-1]['done'] = False + + if data[-1]['done']: + last_value = torch.zeros_like(data[-1]['value']) + else: + with torch.no_grad(): + last_values = [] + for agent_id in range(self._cfg.agent_num): + inputs = {'obs': {k: unsqueeze(v[agent_id], 0) for k, v in data[-1]['next_obs'].items()}} + last_value = self._collect_model.forward(agent_id, inputs, mode='compute_actor_critic')['value'] + last_values.append(last_value) + last_value = torch.cat(last_values) + if len(last_value.shape) == 2: # multi_agent case: + last_value = last_value.squeeze(0) + if self._value_norm: + last_value *= self._running_mean_std.std + for i in range(len(data)): + data[i]['value'] *= self._running_mean_std.std + data = get_gae( + data, + to_device(last_value, self._device), + gamma=self._gamma, + gae_lambda=self._gae_lambda, + cuda=False, + ) + if self._value_norm: + for i in range(len(data)): + data[i]['value'] /= self._running_mean_std.std + + # remove next_obs for save memory when not recompute adv + if not self._recompute_adv: + for i in range(len(data)): + data[i].pop('next_obs') + return get_train_sample(data, self._unroll_len) + + def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For PPO, it contains the \ + eval model to select optimial action (e.g. greedily select action with argmax mechanism in discrete action). + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ + assert self._cfg.action_space in ["continuous", "discrete"] + self._action_space = self._cfg.action_space + if self._action_space == 'continuous': + self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample') + elif self._action_space == 'discrete': + self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') + self._eval_model.reset() + + def _forward_eval(self, data: dict) -> dict: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. ``_forward_eval`` in HAPPO often uses deterministic sample method to \ + get actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \ + exploitation. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. 
+ Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for HAPPOPolicy: ``ding.policy.tests.test_happo``. + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + # transfer data from (B, M, N)->(M, B, N) + data = {k: v.transpose(0, 1) for k, v in data.items()} # not feasible for rnn + self._eval_model.eval() + with torch.no_grad(): + outputs = [] + for agent_id in range(self._cfg.agent_num): + single_agent_obs = {k: v[agent_id] for k, v in data.items()} + input = { + 'obs': single_agent_obs, + } + output = self._eval_model.forward(agent_id, input, mode='compute_actor') + outputs.append(output) + output = self.revert_agent_data(outputs) + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about HAPPO, its registered name is ``happo`` and the import_names is \ + ``ding.model.template.havac``. + """ + return 'havac', ['ding.model.template.havac'] + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ + variables = super()._monitor_vars_learn() + [ + 'policy_loss', + 'value_loss', + 'entropy_loss', + 'adv_max', + 'adv_mean', + 'approx_kl', + 'clipfrac', + 'value_max', + 'value_mean', + ] + if self._action_space == 'continuous': + variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act'] + prefixes = [f'agent{i}_' for i in range(self._cfg.agent_num)] + variables = [prefix + var for prefix in prefixes for var in variables] + return variables + + def revert_agent_data(self, data: list): + """ + Overview: + Revert the data of each agent to the original data format. + Arguments: + - data (:obj:`list`): List type data, where each element is the data of an agent of dict type. + Returns: + - ret (:obj:`dict`): Dict type data, where each element is the data of an agent of dict type. 
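        Examples (illustrative):
            >>> per_agent = [{'logit': torch.randn(4, 3)} for _ in range(5)]  # 5 agents, batch size 4
            >>> merged = policy.revert_agent_data(per_agent)
            >>> assert merged['logit'].shape == (4, 5, 3)  # stacked on the agent dim, then (M, B, N) -> (B, M, N)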
+ """ + ret = {} + # Traverse all keys of the first output + for key in data[0].keys(): + if isinstance(data[0][key], torch.Tensor): + # If the value corresponding to the current key is tensor, stack N tensors + stacked_tensor = torch.stack([output[key] for output in data], dim=0) + ret[key] = stacked_tensor.transpose(0, 1) + elif isinstance(data[0][key], dict): + # If the value corresponding to the current key is a dictionary, recursively \ + # call the function to process the contents inside the dictionary. + ret[key] = self.revert_agent_data([output[key] for output in data]) + return ret diff --git a/ding/rl_utils/__init__.py b/ding/rl_utils/__init__.py index 17dc2b5611..e86f6c1786 100644 --- a/ding/rl_utils/__init__.py +++ b/ding/rl_utils/__init__.py @@ -1,17 +1,19 @@ from .exploration import get_epsilon_greedy_fn, create_noise_generator -from .ppo import ppo_data, ppo_loss, ppo_info, ppo_policy_data, ppo_policy_error, ppo_value_data, ppo_value_error,\ +from .ppo import ppo_data, ppo_loss, ppo_info, ppo_policy_data, ppo_policy_error, ppo_value_data, ppo_value_error, \ ppo_error, ppo_error_continuous, ppo_policy_error_continuous, ppo_data_continuous, ppo_policy_data_continuous +from .happo import happo_data, happo_policy_data, happo_value_data, happo_loss, happo_policy_loss, happo_info, \ + happo_error, happo_policy_error, happo_value_error, happo_error_continuous, happo_policy_error_continuous from .ppg import ppg_data, ppg_joint_loss, ppg_joint_error from .gae import gae_data, gae from .a2c import a2c_data, a2c_error, a2c_error_continuous from .coma import coma_data, coma_error from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, \ - q_1step_td_error, m_q_1step_td_data, m_q_1step_td_error, td_lambda_data, td_lambda_error,\ + q_1step_td_error, m_q_1step_td_data, m_q_1step_td_error, td_lambda_data, td_lambda_error, \ q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \ generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \ - nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\ + nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error, \ fqf_nstep_td_data, fqf_nstep_td_error, fqf_calculate_fraction_loss, evaluate_quantile_at_action, \ - q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data, q_v_1step_td_error, q_v_1step_td_data,\ + q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data, q_v_1step_td_error, q_v_1step_td_data, \ dqfd_nstep_td_error_with_rescale, discount_cumsum, bdq_nstep_td_error from .vtrace import vtrace_loss, compute_importance_weights from .upgo import upgo_loss diff --git a/ding/rl_utils/happo.py b/ding/rl_utils/happo.py new file mode 100644 index 0000000000..b37ddc7528 --- /dev/null +++ b/ding/rl_utils/happo.py @@ -0,0 +1,347 @@ +from collections import namedtuple +from typing import Optional, Tuple +import torch +import torch.nn as nn +from torch.distributions import Independent, Normal +from ding.hpc_rl import hpc_wrapper + +happo_value_data = namedtuple('happo_value_data', ['value_new', 'value_old', 'return_', 'weight']) +happo_loss = namedtuple('happo_loss', ['policy_loss', 'value_loss', 'entropy_loss']) +happo_policy_loss = namedtuple('happo_policy_loss', ['policy_loss', 'entropy_loss']) +happo_info = namedtuple('happo_info', ['approx_kl', 'clipfrac']) +happo_data = namedtuple( + 'happo_data', ['logit_new', 
'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'factor']
+)
+happo_policy_data = namedtuple('happo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'factor'])
+
+
+def happo_error(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        use_value_clip: bool = True,
+        dual_clip: Optional[float] = None,
+) -> Tuple[namedtuple, namedtuple]:
+    """
+    Overview:
+        Implementation of the HAPPO error for discrete actions: the PPO loss (arXiv:1707.06347) with value_clip \
+        and dual_clip, where the policy surrogate is additionally weighted by the sequential-update ``factor``.
+    Arguments:
+        - data (:obj:`namedtuple`): the HAPPO input data with fields shown in ``happo_data``
+        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
+        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
+            defaults to 5.0, if you don't want to use it, set this parameter to None
+    Returns:
+        - happo_loss (:obj:`namedtuple`): the HAPPO loss items, all of which are differentiable 0-dim tensors
+        - happo_info (:obj:`namedtuple`): the optimization information for monitoring, all of which are Python scalars
+    Shapes:
+        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+        - action (:obj:`torch.LongTensor`): :math:`(B, )`
+        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+        - factor (:obj:`torch.FloatTensor`): :math:`(B, 1)`
+        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+    Examples:
+        >>> action_dim = 4
+        >>> data = happo_data(
+        >>>     logit_new=torch.randn(3, action_dim),
+        >>>     logit_old=torch.randn(3, action_dim),
+        >>>     action=torch.randint(0, action_dim, (3,)),
+        >>>     value_new=torch.randn(3),
+        >>>     value_old=torch.randn(3),
+        >>>     adv=torch.randn(3),
+        >>>     return_=torch.randn(3),
+        >>>     weight=torch.ones(3),
+        >>>     factor=torch.ones(3, 1),
+        >>> )
+        >>> loss, info = happo_error(data)
+
+    .. note::
+
+        adv is an already normalized value, i.e. (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
+        ways to calculate this mean and std, such as over the data buffer or the train batch, so we don't couple
+        this part into happo_error; you can refer to our examples for different ways.
+    """
+    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+        dual_clip
+    )
+    logit_new, logit_old, action, value_new, value_old, adv, return_, weight, factor = data
+    policy_data = happo_policy_data(logit_new, logit_old, action, adv, weight, factor)
+    policy_output, policy_info = happo_policy_error(policy_data, clip_ratio, dual_clip)
+    value_data = happo_value_data(value_new, value_old, return_, weight)
+    value_loss = happo_value_error(value_data, clip_ratio, use_value_clip)
+
+    return happo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info
+
+
+def happo_policy_error(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        dual_clip: Optional[float] = None,
+) -> Tuple[namedtuple, namedtuple]:
+    '''
+    Overview:
+        Get the HAPPO policy loss, i.e. the PPO clipped surrogate weighted by the sequential-update ``factor``.
+    Arguments:
+        - data (:obj:`namedtuple`): HAPPO input data with fields shown in ``happo_policy_data``
+        - clip_ratio (:obj:`float`): clip value for ratio
+        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
+            defaults to 5.0, if you don't want to use it, set this parameter to None
+    Returns:
+        - happo_policy_loss (:obj:`namedtuple`): the HAPPO policy loss items, all of which are differentiable \
+            0-dim tensors.
+        - happo_info (:obj:`namedtuple`): the optimization information for monitoring, all of which are Python scalars
+    Shapes:
+        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+        - action (:obj:`torch.LongTensor`): :math:`(B, )`
+        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+        - factor (:obj:`torch.FloatTensor`): :math:`(B, 1)`
+        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+    Examples:
+        >>> action_dim = 4
+        >>> data = happo_policy_data(
+        >>>     logit_new=torch.randn(3, action_dim),
+        >>>     logit_old=torch.randn(3, action_dim),
+        >>>     action=torch.randint(0, action_dim, (3,)),
+        >>>     adv=torch.randn(3),
+        >>>     weight=torch.ones(3),
+        >>>     factor=torch.ones(3, 1),
+        >>> )
+        >>> loss, info = happo_policy_error(data)
+    '''
+    logit_new, logit_old, action, adv, weight, factor = data
+    if weight is None:
+        weight = torch.ones_like(adv)
+    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
+    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
+    logp_new = dist_new.log_prob(action)
+    logp_old = dist_old.log_prob(action)
+    dist_new_entropy = dist_new.entropy()
+    if dist_new_entropy.shape != weight.shape:
+        dist_new_entropy = dist_new.entropy().mean(dim=1)
+    entropy_loss = (dist_new_entropy * weight).mean()
+    # policy_loss
+    ratio = torch.exp(logp_new - logp_old)
+    if ratio.shape != adv.shape:
+        ratio = ratio.mean(dim=1)
+    surr1 = ratio * adv
+    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+    # shape factor: (B, 1), surr1: (B, )
+    clip1 = torch.min(surr1, surr2) * factor.squeeze(1)
+    if dual_clip is not None:
+        clip2 = torch.max(clip1, dual_clip * adv)
+        # only use dual_clip when adv < 0
+        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
+    else:
+        policy_loss = (-clip1 * weight).mean()
+    with torch.no_grad():
+        approx_kl = (logp_old - logp_new).mean().item()
+        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+        clipfrac = torch.as_tensor(clipped).float().mean().item()
+    return happo_policy_loss(policy_loss, entropy_loss), happo_info(approx_kl, clipfrac)
+
+
+def 
happo_value_error(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        use_value_clip: bool = True,
+) -> torch.Tensor:
+    '''
+    Overview:
+        Get the HAPPO value loss, which has the same clipped form as the PPO value loss.
+    Arguments:
+        - data (:obj:`namedtuple`): HAPPO input data with fields shown in ``happo_value_data``
+        - clip_ratio (:obj:`float`): clip value for ratio
+        - use_value_clip (:obj:`bool`): whether to use value clip
+    Returns:
+        - value_loss (:obj:`torch.FloatTensor`): the HAPPO value loss, a differentiable 0-dim tensor
+    Shapes:
+        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
+        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+        - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+    Examples:
+        >>> data = happo_value_data(
+        >>>     value_new=torch.randn(3),
+        >>>     value_old=torch.randn(3),
+        >>>     return_=torch.randn(3),
+        >>>     weight=torch.ones(3),
+        >>> )
+        >>> loss = happo_value_error(data)
+    '''
+    value_new, value_old, return_, weight = data
+    if weight is None:
+        weight = torch.ones_like(value_old)
+    # value_loss
+    if use_value_clip:
+        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
+        v1 = (return_ - value_new).pow(2)
+        v2 = (return_ - value_clip).pow(2)
+        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
+    else:
+        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
+    return value_loss
+
+
+def happo_error_continuous(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        use_value_clip: bool = True,
+        dual_clip: Optional[float] = None,
+) -> Tuple[namedtuple, namedtuple]:
+    """
+    Overview:
+        Implementation of the HAPPO error for continuous actions: the PPO loss (arXiv:1707.06347) with value_clip \
+        and dual_clip, where the policy surrogate is additionally weighted by the sequential-update ``factor``.
+    Arguments:
+        - data (:obj:`namedtuple`): the HAPPO input data with fields shown in ``happo_data``, where the logits are \
+            replaced by the ``mu_sigma`` dicts of the Gaussian policy
+        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
+        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
+            defaults to 5.0, if you don't want to use it, set this parameter to None
+    Returns:
+        - happo_loss (:obj:`namedtuple`): the HAPPO loss items, all of which are differentiable 0-dim tensors
+        - happo_info (:obj:`namedtuple`): the optimization information for monitoring, all of which are Python scalars
+    Shapes:
+        - mu_sigma_new (:obj:`dict`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+        - mu_sigma_old (:obj:`dict`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
+        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+        - factor (:obj:`torch.FloatTensor`): :math:`(B, 1)`
+        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+    Examples:
+        >>> action_dim = 4
+        >>> data = happo_data(
+        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+        >>>     action=torch.randn(3, action_dim),
+        >>>     value_new=torch.randn(3),
+        >>>     value_old=torch.randn(3),
+        >>>     adv=torch.randn(3),
+        >>>     return_=torch.randn(3),
+        >>>     weight=torch.ones(3),
+        >>>     factor=torch.ones(3, 1),
+        >>> )
+        >>> loss, info = happo_error_continuous(data)
+
+    .. note::
+
+        adv is an already normalized value, i.e. (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
+        ways to calculate this mean and std, such as over the data buffer or the train batch, so we don't couple
+        this part into happo_error_continuous; you can refer to our examples for different ways.
+ """ + assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( + dual_clip + ) + mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, factor_batch = data + if weight is None: + weight = torch.ones_like(adv) + + dist_new = Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']) + if len(mu_sigma_old['mu'].shape) == 1: + dist_old = Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)) + else: + dist_old = Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']) + logp_new = dist_new.log_prob(action) + logp_old = dist_old.log_prob(action) + entropy_loss = (dist_new.entropy() * weight.unsqueeze(1)).mean() + + # policy_loss + ratio = torch.exp(logp_new - logp_old) + ratio = torch.prod(ratio, dim=-1) + surr1 = ratio * adv + surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv + if dual_clip is not None: + # shape factor: (B,1) surr1: (B,) + policy_loss = (-torch.max(factor_batch.squeeze(1) * torch.min(surr1, surr2), dual_clip * adv) * weight).mean() + else: + policy_loss = (-factor_batch.squeeze(1) * torch.min(surr1, surr2) * weight).mean() + with torch.no_grad(): + approx_kl = (logp_old - logp_new).mean().item() + clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) + clipfrac = torch.as_tensor(clipped).float().mean().item() + # value_loss + if use_value_clip: + value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio) + v1 = (return_ - value_new).pow(2) + v2 = (return_ - value_clip).pow(2) + value_loss = 0.5 * (torch.max(v1, v2) * weight).mean() + else: + value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean() + + return happo_loss(policy_loss, value_loss, entropy_loss), happo_info(approx_kl, clipfrac) + + +def happo_policy_error_continuous(data: namedtuple, + clip_ratio: float = 0.2, + dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]: + """ + Overview: + Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip + Arguments: + - data (:obj:`namedtuple`): the ppo input data with fieids shown in ``ppo_data`` + - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2 + - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 
5, shoule be in [1, inf),\ + defaults to 5.0, if you don't want to use it, set this parameter to None + Returns: + - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor + - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar + Shapes: + - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim + - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim + - action (:obj:`torch.LongTensor`): :math:`(B, )` + - adv (:obj:`torch.FloatTensor`): :math:`(B, )` + - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )` + - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor + - entropy_loss (:obj:`torch.FloatTensor`): :math:`()` + Examples: + >>> action_dim = 4 + >>> data = ppo_policy_data_continuous( + >>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2), + >>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2), + >>> action=torch.randn(3, action_dim), + >>> adv=torch.randn(3), + >>> weight=torch.ones(3), + >>> ) + >>> loss, info = happo_policy_error_continuous(data) + """ + assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( + dual_clip + ) + mu_sigma_new, mu_sigma_old, action, adv, weight = data + if weight is None: + weight = torch.ones_like(adv) + + dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1) + if len(mu_sigma_old['mu'].shape) == 1: + dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1) + else: + dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1) + logp_new = dist_new.log_prob(action) + logp_old = dist_old.log_prob(action) + entropy_loss = (dist_new.entropy() * weight).mean() + # policy_loss + ratio = torch.exp(logp_new - logp_old) + surr1 = ratio * adv + surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv + if dual_clip is not None: + policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean() + else: + policy_loss = (-torch.min(surr1, surr2) * weight).mean() + with torch.no_grad(): + approx_kl = (logp_old - logp_new).mean().item() + clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) + clipfrac = torch.as_tensor(clipped).float().mean().item() + return happo_policy_loss(policy_loss, entropy_loss), happo_info(approx_kl, clipfrac) diff --git a/ding/rl_utils/tests/test_happo.py b/ding/rl_utils/tests/test_happo.py new file mode 100644 index 0000000000..d82e5a37bc --- /dev/null +++ b/ding/rl_utils/tests/test_happo.py @@ -0,0 +1,71 @@ +import pytest +from itertools import product +import numpy as np +import torch + +from ding.rl_utils import happo_data, happo_error, happo_error_continuous +from ding.rl_utils.ppo import shape_fn_ppo + +use_value_clip_args = [True, False] +dual_clip_args = [None, 5.0] +random_weight = torch.rand(4) + 1 +weight_args = [None, random_weight] +factor_args = [torch.rand(4, 1)] +args = [item for item in product(*[use_value_clip_args, dual_clip_args, weight_args, factor_args])] + + +@pytest.mark.unittest +def test_shape_fn_ppo(): + data = happo_data(torch.randn(3, 5, 8), None, None, None, None, None, None, None, None) + shape1 = shape_fn_ppo([data], {}) + shape2 = shape_fn_ppo([], {'data': data}) + assert shape1 == shape2 == (3, 5, 8) + + +@pytest.mark.unittest 
+@pytest.mark.parametrize('use_value_clip, dual_clip, weight, factor', args) +def test_happo(use_value_clip, dual_clip, weight, factor): + B, N = 4, 32 + logit_new = torch.randn(B, N).requires_grad_(True) + logit_old = logit_new + torch.rand_like(logit_new) * 0.1 + action = torch.randint(0, N, size=(B, )) + value_new = torch.randn(B).requires_grad_(True) + value_old = value_new + torch.rand_like(value_new) * 0.1 + adv = torch.rand(B) + return_ = torch.randn(B) * 2 + data = happo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, weight, factor) + loss, info = happo_error(data, use_value_clip=use_value_clip, dual_clip=dual_clip) + assert all([l.shape == tuple() for l in loss]) + assert all([np.isscalar(i) for i in info]) + assert logit_new.grad is None + assert value_new.grad is None + total_loss = sum(loss) + total_loss.backward() + assert isinstance(logit_new.grad, torch.Tensor) + assert isinstance(value_new.grad, torch.Tensor) + + +@pytest.mark.unittest +@pytest.mark.parametrize('use_value_clip, dual_clip, weight, factor', args) +def test_happo_error_continous(use_value_clip, dual_clip, weight, factor): + B, N = 4, 6 + mu_sigma_new = {'mu': torch.rand(B, N).requires_grad_(True), 'sigma': torch.rand(B, N).requires_grad_(True)} + mu_sigma_old = { + 'mu': mu_sigma_new['mu'] + torch.rand_like(mu_sigma_new['mu']) * 0.1, + 'sigma': mu_sigma_new['sigma'] + torch.rand_like(mu_sigma_new['sigma']) * 0.1 + } + action = torch.rand(B, N) + value_new = torch.randn(B).requires_grad_(True) + value_old = value_new + torch.rand_like(value_new) * 0.1 + adv = torch.rand(B) + return_ = torch.randn(B) * 2 + data = happo_data(mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, factor) + loss, info = happo_error_continuous(data, use_value_clip=use_value_clip, dual_clip=dual_clip) + assert all([l.shape == tuple() for l in loss]) + assert all([np.isscalar(i) for i in info]) + assert mu_sigma_new['mu'].grad is None + assert value_new.grad is None + total_loss = sum(loss) + total_loss.backward() + assert isinstance(mu_sigma_new['mu'].grad, torch.Tensor) + assert isinstance(value_new.grad, torch.Tensor) diff --git a/dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py b/dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py new file mode 100644 index 0000000000..c849551d94 --- /dev/null +++ b/dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py @@ -0,0 +1,83 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +n_agent = 2 + +main_config = dict( + exp_name='HAPPO_result/debug/multi_mujoco_halfcheetah_2x3_happo', + env=dict( + scenario='HalfCheetah-v2', + agent_conf="2x3", + agent_obsk=2, + add_agent_id=False, + episode_limit=1000, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + stop_value=6000, + ), + policy=dict( + cuda=True, + multi_agent=True, + agent_num=n_agent, + action_space='continuous', + model=dict( + action_space='continuous', + agent_num=n_agent, + agent_obs_shape=8, + global_obs_shape=17, + action_shape=3, + use_lstm=False, + ), + learn=dict( + epoch_per_collect=5, + # batch_size=3200, + batch_size=800, + learning_rate=5e-4, + critic_learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss 
weight of entropy regularization, policy network weight is set to 1 + # entropy_weight=0.001, + entropy_weight=0.001, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=True, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=3, + ignore_done=True, + # ignore_done=False, + ), + collect=dict( + n_sample=3200, + unroll_len=1, + env_num=collector_env_num, + ), + eval=dict( + env_num=evaluator_env_num, + evaluator=dict(eval_freq=1000, ), + ), + other=dict(), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='mujoco_multi', + import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='happo'), +) +create_config = EasyDict(create_config) + +if __name__ == '__main__': + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py b/dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py new file mode 100644 index 0000000000..b6db3feea7 --- /dev/null +++ b/dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py @@ -0,0 +1,80 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 + +main_config = dict( + exp_name='HAPPO_result/multi_mujoco_halfcheetah_2x3_mappo', + env=dict( + scenario='HalfCheetah-v2', + agent_conf="2x3", + agent_obsk=2, + add_agent_id=False, + episode_limit=1000, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + stop_value=6000, + ), + policy=dict( + cuda=True, + multi_agent=True, + action_space='continuous', + model=dict( + # (int) agent_num: The number of the agent. + # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2. + agent_num=2, + # (int) obs_shape: The shapeension of observation of each agent. + # For 3s5z, obs_shape=150; for 2c_vs_64zg, agent_num=404. + # (int) global_obs_shape: The shapeension of global observation. + # For 3s5z, obs_shape=216; for 2c_vs_64zg, agent_num=342. + agent_obs_shape=8, + #global_obs_shape=216, + global_obs_shape=17, + # (int) action_shape: The number of action which each agent can take. + # action_shape= the number of common action (6) + the number of enemies. + # For 3s5z, obs_shape=14 (6+8); for 2c_vs_64zg, agent_num=70 (6+64). 
+ action_shape=3, + # (List[int]) The size of hidden layer + # hidden_size_list=[64], + action_space='continuous' + ), + # used in state_num of hidden_state + learn=dict( + epoch_per_collect=5, + batch_size=800, + learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.001, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=True, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=5, + ), + collect=dict(env_num=collector_env_num, n_sample=3200), + eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='mujoco_multi', + import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo'), +) +create_config = EasyDict(create_config) + +if __name__ == '__main__': + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/multiagent_mujoco/config/walker2d_happo_config.py b/dizoo/multiagent_mujoco/config/walker2d_happo_config.py new file mode 100644 index 0000000000..a947a25589 --- /dev/null +++ b/dizoo/multiagent_mujoco/config/walker2d_happo_config.py @@ -0,0 +1,91 @@ +from easydict import EasyDict +import os +collector_env_num = 8 +evaluator_env_num = 8 +n_agent = 2 + +main_config = dict( + exp_name='HAPPO_result/debug/multi_mujoco_walker_2x3_happo', + env=dict( + scenario='Walker2d-v2', + agent_conf="2x3", + agent_obsk=2, + add_agent_id=False, + episode_limit=1000, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + stop_value=6000, + ), + policy=dict( + cuda=True, + multi_agent=True, + agent_num=n_agent, + action_space='continuous', + model=dict( + action_space='continuous', + agent_num=n_agent, + agent_obs_shape=8, + global_obs_shape=17, + action_shape=3, + use_lstm=False, + ), + learn=dict( + epoch_per_collect=5, + # batch_size=3200, + # batch_size=800, + batch_size=320, + # batch_size=100, + learning_rate=5e-4, + critic_learning_rate=5e-3, + # learning_rate=3e-3, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + # value_weight=0.5, + value_weight=1, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + # entropy_weight=0.001, + entropy_weight=0.003, + # entropy_weight=0.005, + # entropy_weight=0.01, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=True, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + # grad_clip_value=5, + grad_clip_value=10, + # ignore_done=True, + ignore_done=False, + ), + collect=dict( + n_sample=3200, + # n_sample=4000, + unroll_len=1, + env_num=collector_env_num, + ), + eval=dict( + env_num=evaluator_env_num, + 
evaluator=dict(eval_freq=1000, ), + ), + other=dict(), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='mujoco_multi', + import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='happo'), +) +create_config = EasyDict(create_config) + +if __name__ == '__main__': + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py b/dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py new file mode 100644 index 0000000000..d1ff088326 --- /dev/null +++ b/dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py @@ -0,0 +1,88 @@ +from easydict import EasyDict + +n_agent = 3 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +main_config = dict( + exp_name='ptz_simple_spread_happo_seed0', + env=dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + stop_value=0, + ), + policy=dict( + cuda=True, + multi_agent=True, + agent_num=n_agent, + action_space='discrete', + model=dict( + action_space='discrete', + agent_num=n_agent, + agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + action_shape=5, + use_lstm=False, + ), + learn=dict( + multi_gpu=False, + epoch_per_collect=5, + batch_size=3200, + learning_rate=5e-4, + critic_learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.01, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=False, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=10, + ignore_done=False, + ), + collect=dict( + n_sample=3200, + unroll_len=1, + env_num=collector_env_num, + ), + eval=dict( + env_num=evaluator_env_num, + evaluator=dict(eval_freq=50, ), + ), + other=dict(), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='base'), + policy=dict(type='happo'), +) +create_config = EasyDict(create_config) +ptz_simple_spread_happo_config = main_config +ptz_simple_spread_happo_create_config = create_config + +if __name__ == '__main__': + # or you can enter `ding -m serial_onpolicy -c ptz_simple_spread_happo_config.py -s 0` + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy((main_config, create_config), seed=0)
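
Note on the ``factor`` field used throughout ``ding/rl_utils/happo.py``: HAPPO (arXiv:2109.11251) updates the agents one at a time and re-weights each agent's clipped surrogate by the likelihood ratio accumulated from the agents already updated in the current iteration. The snippet below is a minimal, self-contained sketch of that bookkeeping for the discrete-action case; only ``happo_policy_data`` and ``happo_policy_error`` come from this patch, while the actors, optimizers, batch layout and entropy weight are hypothetical stand-ins. The actual bookkeeping lives in ``ding/policy/happo.py`` and may differ in detail.

import torch
from torch.distributions import Categorical

from ding.rl_utils import happo_policy_data, happo_policy_error  # exported by this patch


def sequential_happo_update(actors, optimizers, batch, clip_ratio=0.2, entropy_weight=0.01):
    """Illustrative agent-by-agent HAPPO update that maintains the running correction factor."""
    # Assumed batch layout (per agent index i):
    #   batch['obs'][i]: (B, obs_dim), batch['logit'][i]: (B, action_dim) behaviour logits,
    #   batch['action'][i]: (B, ) sampled actions, batch['adv']: (B, ) shared normalized advantage.
    B = batch['adv'].shape[0]
    factor = torch.ones(B, 1)  # M = 1 before any agent has been updated
    for i, (actor, opt) in enumerate(zip(actors, optimizers)):
        logit_old = batch['logit'][i]
        logit_new = actor(batch['obs'][i])
        data = happo_policy_data(
            logit_new=logit_new,
            logit_old=logit_old,
            action=batch['action'][i],
            adv=batch['adv'],
            weight=None,
            factor=factor,
        )
        (policy_loss, entropy_loss), _ = happo_policy_error(data, clip_ratio=clip_ratio)
        opt.zero_grad()
        (policy_loss - entropy_weight * entropy_loss).backward()
        opt.step()
        # Fold agent i's post-update likelihood ratio (new policy vs. behaviour policy)
        # into the factor that the next agent's surrogate will be weighted by.
        with torch.no_grad():
            logp_after = Categorical(logits=actor(batch['obs'][i])).log_prob(batch['action'][i])
            logp_behaviour = Categorical(logits=logit_old).log_prob(batch['action'][i])
            factor = factor * torch.exp(logp_after - logp_behaviour).unsqueeze(-1)
    return factor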